Port asyncio to Python 2, trollius/ directory

This commit is contained in:
Victor Stinner
2015-07-07 21:17:17 +02:00
parent a01f3f4993
commit a4749501b4
22 changed files with 1321 additions and 787 deletions

View File

@@ -1,4 +1,4 @@
"""The asyncio package, tracking PEP 3156."""
"""The trollius package, tracking PEP 3156."""
import sys
@@ -24,6 +24,7 @@ from .events import *
from .futures import *
from .locks import *
from .protocols import *
from .py33_exceptions import *
from .queues import *
from .streams import *
from .subprocess import *
@@ -33,6 +34,7 @@ from .transports import *
__all__ = (base_events.__all__ +
coroutines.__all__ +
events.__all__ +
py33_exceptions.__all__ +
futures.__all__ +
locks.__all__ +
protocols.__all__ +
@@ -48,3 +50,10 @@ if sys.platform == 'win32': # pragma: no cover
else:
from .unix_events import * # pragma: no cover
__all__ += unix_events.__all__
try:
from .py3_ssl import *
__all__ += py3_ssl.__all__
except ImportError:
# SSL support is optionnal
pass

View File

@@ -15,25 +15,34 @@ to modify the meaning of the API call itself.
import collections
import concurrent.futures
import heapq
import inspect
import logging
import os
import socket
import subprocess
import threading
import time
import traceback
import sys
import warnings
import traceback
try:
from collections import OrderedDict
except ImportError:
# Python 2.6: use ordereddict backport
from ordereddict import OrderedDict
try:
from threading import get_ident as _get_thread_ident
except ImportError:
# Python 2
from threading import _get_ident as _get_thread_ident
from . import compat
from . import coroutines
from . import events
from . import futures
from . import tasks
from .coroutines import coroutine
from .coroutines import coroutine, From, Return
from .executor import get_default_executor
from .log import logger
from .time_monotonic import time_monotonic, time_monotonic_resolution
__all__ = ['BaseEventLoop']
@@ -171,10 +180,10 @@ class Server(events.AbstractServer):
@coroutine
def wait_closed(self):
if self.sockets is None or self._waiters is None:
return
raise Return()
waiter = futures.Future(loop=self._loop)
self._waiters.append(waiter)
yield from waiter
yield From(waiter)
class BaseEventLoop(events.AbstractEventLoop):
@@ -191,8 +200,7 @@ class BaseEventLoop(events.AbstractEventLoop):
self._thread_id = None
self._clock_resolution = time.get_clock_info('monotonic').resolution
self._exception_handler = None
self.set_debug((not sys.flags.ignore_environment
and bool(os.environ.get('PYTHONASYNCIODEBUG'))))
self.set_debug(bool(os.environ.get('TROLLIUSDEBUG')))
# In debug mode, if the execution of a callback or a step of a task
# exceed this duration in seconds, the slow callback/task is logged.
self.slow_callback_duration = 0.1
@@ -237,13 +245,13 @@ class BaseEventLoop(events.AbstractEventLoop):
"""Return a task factory, or None if the default one is in use."""
return self._task_factory
def _make_socket_transport(self, sock, protocol, waiter=None, *,
def _make_socket_transport(self, sock, protocol, waiter=None,
extra=None, server=None):
"""Create socket transport."""
raise NotImplementedError
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
server_side=False, server_hostname=None,
extra=None, server=None):
"""Create SSL transport."""
raise NotImplementedError
@@ -506,7 +514,7 @@ class BaseEventLoop(events.AbstractEventLoop):
if executor is None:
executor = self._default_executor
if executor is None:
executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS)
executor = get_default_executor()
self._default_executor = executor
return futures.wrap_future(executor.submit(func, *args), loop=self)
@@ -538,7 +546,7 @@ class BaseEventLoop(events.AbstractEventLoop):
logger.debug(msg)
return addrinfo
def getaddrinfo(self, host, port, *,
def getaddrinfo(self, host, port,
family=0, type=0, proto=0, flags=0):
if self._debug:
return self.run_in_executor(None, self._getaddrinfo_debug,
@@ -551,7 +559,7 @@ class BaseEventLoop(events.AbstractEventLoop):
return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags)
@coroutine
def create_connection(self, protocol_factory, host=None, port=None, *,
def create_connection(self, protocol_factory, host=None, port=None,
ssl=None, family=0, proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None):
"""Connect to a TCP server.
@@ -601,15 +609,15 @@ class BaseEventLoop(events.AbstractEventLoop):
else:
f2 = None
yield from tasks.wait(fs, loop=self)
yield From(tasks.wait(fs, loop=self))
infos = f1.result()
if not infos:
raise OSError('getaddrinfo() returned empty list')
raise socket.error('getaddrinfo() returned empty list')
if f2 is not None:
laddr_infos = f2.result()
if not laddr_infos:
raise OSError('getaddrinfo() returned empty list')
raise socket.error('getaddrinfo() returned empty list')
exceptions = []
for family, type, proto, cname, address in infos:
@@ -621,11 +629,11 @@ class BaseEventLoop(events.AbstractEventLoop):
try:
sock.bind(laddr)
break
except OSError as exc:
exc = OSError(
except socket.error as exc:
exc = socket.error(
exc.errno, 'error while '
'attempting to bind on address '
'{!r}: {}'.format(
'{0!r}: {1}'.format(
laddr, exc.strerror.lower()))
exceptions.append(exc)
else:
@@ -634,8 +642,8 @@ class BaseEventLoop(events.AbstractEventLoop):
continue
if self._debug:
logger.debug("connect %r to %r", sock, address)
yield from self.sock_connect(sock, address)
except OSError as exc:
yield From(self.sock_connect(sock, address))
except socket.error as exc:
if sock is not None:
sock.close()
exceptions.append(exc)
@@ -655,7 +663,7 @@ class BaseEventLoop(events.AbstractEventLoop):
raise exceptions[0]
# Raise a combined exception so the user can see all
# the various error messages.
raise OSError('Multiple exceptions: {}'.format(
raise socket.error('Multiple exceptions: {0}'.format(
', '.join(str(exc) for exc in exceptions)))
elif sock is None:
@@ -664,15 +672,15 @@ class BaseEventLoop(events.AbstractEventLoop):
sock.setblocking(False)
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
transport, protocol = yield From(self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname))
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
sock = transport.get_extra_info('socket')
logger.debug("%r connected to %s:%r: (%r, %r)",
sock, host, port, transport, protocol)
return transport, protocol
raise Return(transport, protocol)
@coroutine
def _create_connection_transport(self, sock, protocol_factory, ssl,
@@ -688,12 +696,12 @@ class BaseEventLoop(events.AbstractEventLoop):
transport = self._make_socket_transport(sock, protocol, waiter)
try:
yield from waiter
yield From(waiter)
except:
transport.close()
raise
return transport, protocol
raise Return(transport, protocol)
@coroutine
def create_datagram_endpoint(self, protocol_factory,
@@ -706,17 +714,17 @@ class BaseEventLoop(events.AbstractEventLoop):
addr_pairs_info = (((family, proto), (None, None)),)
else:
# join address by (family, protocol)
addr_infos = collections.OrderedDict()
addr_infos = OrderedDict()
for idx, addr in ((0, local_addr), (1, remote_addr)):
if addr is not None:
assert isinstance(addr, tuple) and len(addr) == 2, (
'2-tuple is expected')
infos = yield from self.getaddrinfo(
infos = yield From(self.getaddrinfo(
*addr, family=family, type=socket.SOCK_DGRAM,
proto=proto, flags=flags)
proto=proto, flags=flags))
if not infos:
raise OSError('getaddrinfo() returned empty list')
raise socket.error('getaddrinfo() returned empty list')
for fam, _, pro, _, address in infos:
key = (fam, pro)
@@ -748,9 +756,9 @@ class BaseEventLoop(events.AbstractEventLoop):
if local_addr:
sock.bind(local_address)
if remote_addr:
yield from self.sock_connect(sock, remote_address)
yield From(self.sock_connect(sock, remote_address))
r_addr = remote_address
except OSError as exc:
except socket.error as exc:
if sock is not None:
sock.close()
exceptions.append(exc)
@@ -778,16 +786,15 @@ class BaseEventLoop(events.AbstractEventLoop):
remote_addr, transport, protocol)
try:
yield from waiter
yield From(waiter)
except:
transport.close()
raise
return transport, protocol
raise Return(transport, protocol)
@coroutine
def create_server(self, protocol_factory, host=None, port=None,
*,
family=socket.AF_UNSPEC,
flags=socket.AI_PASSIVE,
sock=None,
@@ -814,11 +821,11 @@ class BaseEventLoop(events.AbstractEventLoop):
if host == '':
host = None
infos = yield from self.getaddrinfo(
infos = yield From(self.getaddrinfo(
host, port, family=family,
type=socket.SOCK_STREAM, proto=0, flags=flags)
type=socket.SOCK_STREAM, proto=0, flags=flags))
if not infos:
raise OSError('getaddrinfo() returned empty list')
raise socket.error('getaddrinfo() returned empty list')
completed = False
try:
@@ -846,10 +853,11 @@ class BaseEventLoop(events.AbstractEventLoop):
True)
try:
sock.bind(sa)
except OSError as err:
raise OSError(err.errno, 'error while attempting '
'to bind on address %r: %s'
% (sa, err.strerror.lower()))
except socket.error as err:
raise socket.error(err.errno,
'error while attempting '
'to bind on address %r: %s'
% (sa, err.strerror.lower()))
completed = True
finally:
if not completed:
@@ -867,7 +875,7 @@ class BaseEventLoop(events.AbstractEventLoop):
self._start_serving(protocol_factory, sock, ssl, server)
if self._debug:
logger.info("%r is serving", server)
return server
raise Return(server)
@coroutine
def connect_read_pipe(self, protocol_factory, pipe):
@@ -876,7 +884,7 @@ class BaseEventLoop(events.AbstractEventLoop):
transport = self._make_read_pipe_transport(pipe, protocol, waiter)
try:
yield from waiter
yield From(waiter)
except:
transport.close()
raise
@@ -884,7 +892,7 @@ class BaseEventLoop(events.AbstractEventLoop):
if self._debug:
logger.debug('Read pipe %r connected: (%r, %r)',
pipe.fileno(), transport, protocol)
return transport, protocol
raise Return(transport, protocol)
@coroutine
def connect_write_pipe(self, protocol_factory, pipe):
@@ -893,7 +901,7 @@ class BaseEventLoop(events.AbstractEventLoop):
transport = self._make_write_pipe_transport(pipe, protocol, waiter)
try:
yield from waiter
yield From(waiter)
except:
transport.close()
raise
@@ -901,7 +909,7 @@ class BaseEventLoop(events.AbstractEventLoop):
if self._debug:
logger.debug('Write pipe %r connected: (%r, %r)',
pipe.fileno(), transport, protocol)
return transport, protocol
raise Return(transport, protocol)
def _log_subprocess(self, msg, stdin, stdout, stderr):
info = [msg]
@@ -917,11 +925,11 @@ class BaseEventLoop(events.AbstractEventLoop):
logger.debug(' '.join(info))
@coroutine
def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE,
def subprocess_shell(self, protocol_factory, cmd, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
universal_newlines=False, shell=True, bufsize=0,
**kwargs):
if not isinstance(cmd, (bytes, str)):
if not isinstance(cmd, compat.string_types):
raise ValueError("cmd must be a string")
if universal_newlines:
raise ValueError("universal_newlines must be False")
@@ -935,17 +943,20 @@ class BaseEventLoop(events.AbstractEventLoop):
# (password) and may be too long
debug_log = 'run shell command %r' % cmd
self._log_subprocess(debug_log, stdin, stdout, stderr)
transport = yield from self._make_subprocess_transport(
protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs)
transport = yield From(self._make_subprocess_transport(
protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs))
if self._debug:
logger.info('%s: %r' % (debug_log, transport))
return transport, protocol
raise Return(transport, protocol)
@coroutine
def subprocess_exec(self, protocol_factory, program, *args,
stdin=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, universal_newlines=False,
shell=False, bufsize=0, **kwargs):
def subprocess_exec(self, protocol_factory, program, *args, **kwargs):
stdin = kwargs.pop('stdin', subprocess.PIPE)
stdout = kwargs.pop('stdout', subprocess.PIPE)
stderr = kwargs.pop('stderr', subprocess.PIPE)
universal_newlines = kwargs.pop('universal_newlines', False)
shell = kwargs.pop('shell', False)
bufsize = kwargs.pop('bufsize', 0)
if universal_newlines:
raise ValueError("universal_newlines must be False")
if shell:
@@ -954,7 +965,7 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError("bufsize must be 0")
popen_args = (program,) + args
for arg in popen_args:
if not isinstance(arg, (str, bytes)):
if not isinstance(arg, compat.string_types ):
raise TypeError("program arguments must be "
"a bytes or text string, not %s"
% type(arg).__name__)
@@ -964,12 +975,12 @@ class BaseEventLoop(events.AbstractEventLoop):
# (password) and may be too long
debug_log = 'execute program %r' % program
self._log_subprocess(debug_log, stdin, stdout, stderr)
transport = yield from self._make_subprocess_transport(
transport = yield From(self._make_subprocess_transport(
protocol, popen_args, False, stdin, stdout, stderr,
bufsize, **kwargs)
bufsize, **kwargs))
if self._debug:
logger.info('%s: %r' % (debug_log, transport))
return transport, protocol
raise Return(transport, protocol)
def set_exception_handler(self, handler):
"""Set handler as the new event loop exception handler.
@@ -985,7 +996,7 @@ class BaseEventLoop(events.AbstractEventLoop):
"""
if handler is not None and not callable(handler):
raise TypeError('A callable object or None is expected, '
'got {!r}'.format(handler))
'got {0!r}'.format(handler))
self._exception_handler = handler
def default_exception_handler(self, context):
@@ -1004,7 +1015,15 @@ class BaseEventLoop(events.AbstractEventLoop):
exception = context.get('exception')
if exception is not None:
exc_info = (type(exception), exception, exception.__traceback__)
if hasattr(exception, '__traceback__'):
# Python 3
tb = exception.__traceback__
else:
# call_exception_handler() is usually called indirectly
# from an except block. If it's not the case, the traceback
# is undefined...
tb = sys.exc_info()[2]
exc_info = (type(exception), exception, tb)
else:
exc_info = False
@@ -1015,7 +1034,7 @@ class BaseEventLoop(events.AbstractEventLoop):
log_lines = [message]
for key in sorted(context):
if key in {'message', 'exception'}:
if key in ('message', 'exception'):
continue
value = context[key]
if key == 'source_traceback':
@@ -1028,7 +1047,7 @@ class BaseEventLoop(events.AbstractEventLoop):
value += tb.rstrip()
else:
value = repr(value)
log_lines.append('{}: {}'.format(key, value))
log_lines.append('{0}: {1}'.format(key, value))
logger.error('\n'.join(log_lines), exc_info=exc_info)
@@ -1108,7 +1127,7 @@ class BaseEventLoop(events.AbstractEventLoop):
sched_count = len(self._scheduled)
if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and
self._timer_cancelled_count / sched_count >
float(self._timer_cancelled_count) / sched_count >
_MIN_CANCELLED_TIMER_HANDLES_FRACTION):
# Remove delayed calls that were cancelled if their number
# is too high

View File

@@ -6,7 +6,7 @@ import warnings
from . import futures
from . import protocols
from . import transports
from .coroutines import coroutine
from .coroutines import coroutine, From
from .log import logger
@@ -15,7 +15,7 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
def __init__(self, loop, protocol, args, shell,
stdin, stdout, stderr, bufsize,
waiter=None, extra=None, **kwargs):
super().__init__(extra)
super(BaseSubprocessTransport, self).__init__(extra)
self._closed = False
self._protocol = protocol
self._loop = loop
@@ -221,7 +221,8 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
waiter = futures.Future(loop=self._loop)
self._exit_waiters.append(waiter)
return (yield from waiter)
returncode = yield From(waiter)
return returncode
def _try_finish(self):
assert not self._finished

View File

@@ -9,6 +9,7 @@ import sys
import traceback
import types
from . import compat
from . import events
from . import futures
from .log import logger
@@ -18,7 +19,7 @@ _PY35 = sys.version_info >= (3, 5)
# Opcode of "yield from" instruction
_YIELD_FROM = opcode.opmap['YIELD_FROM']
_YIELD_FROM = opcode.opmap.get('YIELD_FROM', None)
# If you set _DEBUG to true, @coroutine will wrap the resulting
# generator objects in a CoroWrapper instance (defined below). That
@@ -29,8 +30,7 @@ _YIELD_FROM = opcode.opmap['YIELD_FROM']
# before you define your coroutines. A downside of using this feature
# is that tracebacks show entries for the CoroWrapper.__next__ method
# when _DEBUG is true.
_DEBUG = (not sys.flags.ignore_environment
and bool(os.environ.get('PYTHONASYNCIODEBUG')))
_DEBUG = bool(os.environ.get('TROLLIUSDEBUG'))
try:
@@ -74,6 +74,53 @@ _YIELD_FROM_BUG = has_yield_from_bug()
del has_yield_from_bug
if compat.PY33:
# Don't use the Return class on Python 3.3 and later to support asyncio
# coroutines (to avoid the warning emited in Return destructor).
#
# The problem is that Return inherits from StopIteration. "yield from
# trollius_coroutine". Task._step() does not receive the Return exception,
# because "yield from" handles it internally. So it's not possible to set
# the raised attribute to True to avoid the warning in Return destructor.
def Return(*args):
if not args:
value = None
elif len(args) == 1:
value = args[0]
else:
value = args
return StopIteration(value)
else:
class Return(StopIteration):
def __init__(self, *args):
StopIteration.__init__(self)
if not args:
self.value = None
elif len(args) == 1:
self.value = args[0]
else:
self.value = args
self.raised = False
if _DEBUG:
frame = sys._getframe(1)
self._source_traceback = traceback.extract_stack(frame)
# explicitly clear the reference to avoid reference cycles
frame = None
else:
self._source_traceback = None
def __del__(self):
if self.raised:
return
fmt = 'Return(%r) used without raise'
if self._source_traceback:
fmt += '\nReturn created at (most recent call last):\n'
tb = ''.join(traceback.format_list(self._source_traceback))
fmt += tb.rstrip()
logger.error(fmt, self.value)
def debug_wrapper(gen):
# This function is called from 'sys.set_coroutine_wrapper'.
# We only wrap here coroutines defined via 'async def' syntax.
@@ -104,7 +151,8 @@ class CoroWrapper:
return self
def __next__(self):
return self.gen.send(None)
return next(self.gen)
next = __next__
if _YIELD_FROM_BUG:
# For for CPython issue #21209: using "yield from" and a custom
@@ -180,6 +228,56 @@ class CoroWrapper:
msg += tb.rstrip()
logger.error(msg)
if not compat.PY34:
# Backport functools.update_wrapper() from Python 3.4:
# - Python 2.7 fails if assigned attributes don't exist
# - Python 2.7 and 3.1 don't set the __wrapped__ attribute
# - Python 3.2 and 3.3 set __wrapped__ before updating __dict__
def _update_wrapper(wrapper,
wrapped,
assigned = functools.WRAPPER_ASSIGNMENTS,
updated = functools.WRAPPER_UPDATES):
"""Update a wrapper function to look like the wrapped function
wrapper is the function to be updated
wrapped is the original function
assigned is a tuple naming the attributes assigned directly
from the wrapped function to the wrapper function (defaults to
functools.WRAPPER_ASSIGNMENTS)
updated is a tuple naming the attributes of the wrapper that
are updated with the corresponding attribute from the wrapped
function (defaults to functools.WRAPPER_UPDATES)
"""
for attr in assigned:
try:
value = getattr(wrapped, attr)
except AttributeError:
pass
else:
setattr(wrapper, attr, value)
for attr in updated:
getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
# Issue #17482: set __wrapped__ last so we don't inadvertently copy it
# from the wrapped function when updating __dict__
wrapper.__wrapped__ = wrapped
# Return the wrapper so this can be used as a decorator via partial()
return wrapper
def _wraps(wrapped,
assigned = functools.WRAPPER_ASSIGNMENTS,
updated = functools.WRAPPER_UPDATES):
"""Decorator factory to apply update_wrapper() to a wrapper function
Returns a decorator that invokes update_wrapper() with the decorated
function as the wrapper argument and the arguments to wraps() as the
remaining arguments. Default arguments are as for update_wrapper().
This is a convenience function to simplify applying partial() to
update_wrapper().
"""
return functools.partial(_update_wrapper, wrapped=wrapped,
assigned=assigned, updated=updated)
else:
_wraps = functools.wraps
def coroutine(func):
"""Decorator to mark coroutines.
@@ -197,7 +295,7 @@ def coroutine(func):
if inspect.isgeneratorfunction(func):
coro = func
else:
@functools.wraps(func)
@_wraps(func)
def coro(*args, **kw):
res = func(*args, **kw)
if isinstance(res, futures.Future) or inspect.isgenerator(res):
@@ -220,7 +318,7 @@ def coroutine(func):
else:
wrapper = _types_coroutine(coro)
else:
@functools.wraps(func)
@_wraps(func)
def wrapper(*args, **kwds):
w = CoroWrapper(coro(*args, **kwds), func=func)
if w._source_traceback:
@@ -246,7 +344,13 @@ def iscoroutinefunction(func):
_COROUTINE_TYPES = (types.GeneratorType, CoroWrapper)
if _CoroutineABC is not None:
_COROUTINE_TYPES += (_CoroutineABC,)
if events.asyncio is not None:
# Accept also asyncio CoroWrapper for interoperability
if hasattr(events.asyncio, 'coroutines'):
_COROUTINE_TYPES += (events.asyncio.coroutines.CoroWrapper,)
else:
# old Tulip/Python versions
_COROUTINE_TYPES += (events.asyncio.tasks.CoroWrapper,)
def iscoroutine(obj):
"""Return True if obj is a coroutine object."""
@@ -299,3 +403,19 @@ def _format_coroutine(coro):
% (coro_name, filename, lineno))
return coro_repr
class FromWrapper(object):
__slots__ = ('obj',)
def __init__(self, obj):
if isinstance(obj, FromWrapper):
obj = obj.obj
assert not isinstance(obj, FromWrapper)
self.obj = obj
def From(obj):
if not _DEBUG:
return obj
else:
return FromWrapper(obj)

View File

@@ -10,12 +10,23 @@ __all__ = ['AbstractEventLoopPolicy',
import functools
import inspect
import reprlib
import socket
import subprocess
import sys
import threading
import traceback
try:
import reprlib # Python 3
except ImportError:
import repr as reprlib # Python 2
from trollius import compat
try:
import asyncio
except (ImportError, SyntaxError):
# ignore SyntaxError for convenience: ignore SyntaxError caused by "yield
# from" if asyncio module is in the Python path
asyncio = None
_PY34 = sys.version_info >= (3, 4)
@@ -75,7 +86,7 @@ def _format_callback_source(func, args):
return func_repr
class Handle:
class Handle(object):
"""Object returned by callback registration methods."""
__slots__ = ('_callback', '_args', '_cancelled', '_loop',
@@ -145,14 +156,14 @@ class TimerHandle(Handle):
def __init__(self, when, callback, args, loop):
assert when is not None
super().__init__(callback, args, loop)
super(TimerHandle, self).__init__(callback, args, loop)
if self._source_traceback:
del self._source_traceback[-1]
self._when = when
self._scheduled = False
def _repr_info(self):
info = super()._repr_info()
info = super(TimerHandle, self)._repr_info()
pos = 2 if self._cancelled else 1
info.insert(pos, 'when=%s' % self._when)
return info
@@ -191,10 +202,10 @@ class TimerHandle(Handle):
def cancel(self):
if not self._cancelled:
self._loop._timer_handle_cancelled(self)
super().cancel()
super(TimerHandle, self).cancel()
class AbstractServer:
class AbstractServer(object):
"""Abstract server returned by create_server()."""
def close(self):
@@ -206,298 +217,303 @@ class AbstractServer:
return NotImplemented
class AbstractEventLoop:
"""Abstract event loop."""
if asyncio is not None:
# Reuse asyncio classes so asyncio.set_event_loop() and
# asyncio.set_event_loop_policy() accept Trollius event loop and trollius
# event loop policy
AbstractEventLoop = asyncio.AbstractEventLoop
AbstractEventLoopPolicy = asyncio.AbstractEventLoopPolicy
else:
class AbstractEventLoop(object):
"""Abstract event loop."""
# Running and stopping the event loop.
# Running and stopping the event loop.
def run_forever(self):
"""Run the event loop until stop() is called."""
raise NotImplementedError
def run_forever(self):
"""Run the event loop until stop() is called."""
raise NotImplementedError
def run_until_complete(self, future):
"""Run the event loop until a Future is done.
def run_until_complete(self, future):
"""Run the event loop until a Future is done.
Return the Future's result, or raise its exception.
"""
raise NotImplementedError
Return the Future's result, or raise its exception.
"""
raise NotImplementedError
def stop(self):
"""Stop the event loop as soon as reasonable.
def stop(self):
"""Stop the event loop as soon as reasonable.
Exactly how soon that is may depend on the implementation, but
no more I/O callbacks should be scheduled.
"""
raise NotImplementedError
Exactly how soon that is may depend on the implementation, but
no more I/O callbacks should be scheduled.
"""
raise NotImplementedError
def is_running(self):
"""Return whether the event loop is currently running."""
raise NotImplementedError
def is_running(self):
"""Return whether the event loop is currently running."""
raise NotImplementedError
def is_closed(self):
"""Returns True if the event loop was closed."""
raise NotImplementedError
def is_closed(self):
"""Returns True if the event loop was closed."""
raise NotImplementedError
def close(self):
"""Close the loop.
def close(self):
"""Close the loop.
The loop should not be running.
The loop should not be running.
This is idempotent and irreversible.
This is idempotent and irreversible.
No other methods should be called after this one.
"""
raise NotImplementedError
No other methods should be called after this one.
"""
raise NotImplementedError
# Methods scheduling callbacks. All these return Handles.
# Methods scheduling callbacks. All these return Handles.
def _timer_handle_cancelled(self, handle):
"""Notification that a TimerHandle has been cancelled."""
raise NotImplementedError
def _timer_handle_cancelled(self, handle):
"""Notification that a TimerHandle has been cancelled."""
raise NotImplementedError
def call_soon(self, callback, *args):
return self.call_later(0, callback, *args)
def call_soon(self, callback, *args):
return self.call_later(0, callback, *args)
def call_later(self, delay, callback, *args):
raise NotImplementedError
def call_later(self, delay, callback, *args):
raise NotImplementedError
def call_at(self, when, callback, *args):
raise NotImplementedError
def call_at(self, when, callback, *args):
raise NotImplementedError
def time(self):
raise NotImplementedError
def time(self):
raise NotImplementedError
# Method scheduling a coroutine object: create a task.
# Method scheduling a coroutine object: create a task.
def create_task(self, coro):
raise NotImplementedError
def create_task(self, coro):
raise NotImplementedError
# Methods for interacting with threads.
# Methods for interacting with threads.
def call_soon_threadsafe(self, callback, *args):
raise NotImplementedError
def call_soon_threadsafe(self, callback, *args):
raise NotImplementedError
def run_in_executor(self, executor, func, *args):
raise NotImplementedError
def run_in_executor(self, executor, func, *args):
raise NotImplementedError
def set_default_executor(self, executor):
raise NotImplementedError
def set_default_executor(self, executor):
raise NotImplementedError
# Network I/O methods returning Futures.
# Network I/O methods returning Futures.
def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0):
raise NotImplementedError
def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0):
raise NotImplementedError
def getnameinfo(self, sockaddr, flags=0):
raise NotImplementedError
def getnameinfo(self, sockaddr, flags=0):
raise NotImplementedError
def create_connection(self, protocol_factory, host=None, port=None, *,
ssl=None, family=0, proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None):
raise NotImplementedError
def create_connection(self, protocol_factory, host=None, port=None,
ssl=None, family=0, proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None):
raise NotImplementedError
def create_server(self, protocol_factory, host=None, port=None, *,
family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE,
sock=None, backlog=100, ssl=None, reuse_address=None):
"""A coroutine which creates a TCP server bound to host and port.
def create_server(self, protocol_factory, host=None, port=None,
family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE,
sock=None, backlog=100, ssl=None, reuse_address=None):
"""A coroutine which creates a TCP server bound to host and port.
The return value is a Server object which can be used to stop
the service.
The return value is a Server object which can be used to stop
the service.
If host is an empty string or None all interfaces are assumed
and a list of multiple sockets will be returned (most likely
one for IPv4 and another one for IPv6).
If host is an empty string or None all interfaces are assumed
and a list of multiple sockets will be returned (most likely
one for IPv4 and another one for IPv6).
family can be set to either AF_INET or AF_INET6 to force the
socket to use IPv4 or IPv6. If not set it will be determined
from host (defaults to AF_UNSPEC).
family can be set to either AF_INET or AF_INET6 to force the
socket to use IPv4 or IPv6. If not set it will be determined
from host (defaults to AF_UNSPEC).
flags is a bitmask for getaddrinfo().
flags is a bitmask for getaddrinfo().
sock can optionally be specified in order to use a preexisting
socket object.
sock can optionally be specified in order to use a preexisting
socket object.
backlog is the maximum number of queued connections passed to
listen() (defaults to 100).
backlog is the maximum number of queued connections passed to
listen() (defaults to 100).
ssl can be set to an SSLContext to enable SSL over the
accepted connections.
ssl can be set to an SSLContext to enable SSL over the
accepted connections.
reuse_address tells the kernel to reuse a local socket in
TIME_WAIT state, without waiting for its natural timeout to
expire. If not specified will automatically be set to True on
UNIX.
"""
raise NotImplementedError
reuse_address tells the kernel to reuse a local socket in
TIME_WAIT state, without waiting for its natural timeout to
expire. If not specified will automatically be set to True on
UNIX.
"""
raise NotImplementedError
def create_unix_connection(self, protocol_factory, path, *,
ssl=None, sock=None,
server_hostname=None):
raise NotImplementedError
def create_unix_connection(self, protocol_factory, path,
ssl=None, sock=None,
server_hostname=None):
raise NotImplementedError
def create_unix_server(self, protocol_factory, path, *,
sock=None, backlog=100, ssl=None):
"""A coroutine which creates a UNIX Domain Socket server.
def create_unix_server(self, protocol_factory, path,
sock=None, backlog=100, ssl=None):
"""A coroutine which creates a UNIX Domain Socket server.
The return value is a Server object, which can be used to stop
the service.
The return value is a Server object, which can be used to stop
the service.
path is a str, representing a file systsem path to bind the
server socket to.
path is a str, representing a file systsem path to bind the
server socket to.
sock can optionally be specified in order to use a preexisting
socket object.
sock can optionally be specified in order to use a preexisting
socket object.
backlog is the maximum number of queued connections passed to
listen() (defaults to 100).
backlog is the maximum number of queued connections passed to
listen() (defaults to 100).
ssl can be set to an SSLContext to enable SSL over the
accepted connections.
"""
raise NotImplementedError
ssl can be set to an SSLContext to enable SSL over the
accepted connections.
"""
raise NotImplementedError
def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *,
family=0, proto=0, flags=0):
raise NotImplementedError
def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None,
family=0, proto=0, flags=0):
raise NotImplementedError
# Pipes and subprocesses.
# Pipes and subprocesses.
def connect_read_pipe(self, protocol_factory, pipe):
"""Register read pipe in event loop. Set the pipe to non-blocking mode.
def connect_read_pipe(self, protocol_factory, pipe):
"""Register read pipe in event loop. Set the pipe to non-blocking mode.
protocol_factory should instantiate object with Protocol interface.
pipe is a file-like object.
Return pair (transport, protocol), where transport supports the
ReadTransport interface."""
# The reason to accept file-like object instead of just file descriptor
# is: we need to own pipe and close it at transport finishing
# Can got complicated errors if pass f.fileno(),
# close fd in pipe transport then close f and vise versa.
raise NotImplementedError
protocol_factory should instantiate object with Protocol interface.
pipe is a file-like object.
Return pair (transport, protocol), where transport supports the
ReadTransport interface."""
# The reason to accept file-like object instead of just file descriptor
# is: we need to own pipe and close it at transport finishing
# Can got complicated errors if pass f.fileno(),
# close fd in pipe transport then close f and vise versa.
raise NotImplementedError
def connect_write_pipe(self, protocol_factory, pipe):
"""Register write pipe in event loop.
def connect_write_pipe(self, protocol_factory, pipe):
"""Register write pipe in event loop.
protocol_factory should instantiate object with BaseProtocol interface.
Pipe is file-like object already switched to nonblocking.
Return pair (transport, protocol), where transport support
WriteTransport interface."""
# The reason to accept file-like object instead of just file descriptor
# is: we need to own pipe and close it at transport finishing
# Can got complicated errors if pass f.fileno(),
# close fd in pipe transport then close f and vise versa.
raise NotImplementedError
protocol_factory should instantiate object with BaseProtocol interface.
Pipe is file-like object already switched to nonblocking.
Return pair (transport, protocol), where transport support
WriteTransport interface."""
# The reason to accept file-like object instead of just file descriptor
# is: we need to own pipe and close it at transport finishing
# Can got complicated errors if pass f.fileno(),
# close fd in pipe transport then close f and vise versa.
raise NotImplementedError
def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
**kwargs):
raise NotImplementedError
def subprocess_shell(self, protocol_factory, cmd, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
**kwargs):
raise NotImplementedError
def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
**kwargs):
raise NotImplementedError
def subprocess_exec(self, protocol_factory, *args, **kwargs):
raise NotImplementedError
# Ready-based callback registration methods.
# The add_*() methods return None.
# The remove_*() methods return True if something was removed,
# False if there was nothing to delete.
# Ready-based callback registration methods.
# The add_*() methods return None.
# The remove_*() methods return True if something was removed,
# False if there was nothing to delete.
def add_reader(self, fd, callback, *args):
raise NotImplementedError
def add_reader(self, fd, callback, *args):
raise NotImplementedError
def remove_reader(self, fd):
raise NotImplementedError
def remove_reader(self, fd):
raise NotImplementedError
def add_writer(self, fd, callback, *args):
raise NotImplementedError
def add_writer(self, fd, callback, *args):
raise NotImplementedError
def remove_writer(self, fd):
raise NotImplementedError
def remove_writer(self, fd):
raise NotImplementedError
# Completion based I/O methods returning Futures.
# Completion based I/O methods returning Futures.
def sock_recv(self, sock, nbytes):
raise NotImplementedError
def sock_recv(self, sock, nbytes):
raise NotImplementedError
def sock_sendall(self, sock, data):
raise NotImplementedError
def sock_sendall(self, sock, data):
raise NotImplementedError
def sock_connect(self, sock, address):
raise NotImplementedError
def sock_connect(self, sock, address):
raise NotImplementedError
def sock_accept(self, sock):
raise NotImplementedError
def sock_accept(self, sock):
raise NotImplementedError
# Signal handling.
# Signal handling.
def add_signal_handler(self, sig, callback, *args):
raise NotImplementedError
def add_signal_handler(self, sig, callback, *args):
raise NotImplementedError
def remove_signal_handler(self, sig):
raise NotImplementedError
def remove_signal_handler(self, sig):
raise NotImplementedError
# Task factory.
# Task factory.
def set_task_factory(self, factory):
raise NotImplementedError
def set_task_factory(self, factory):
raise NotImplementedError
def get_task_factory(self):
raise NotImplementedError
def get_task_factory(self):
raise NotImplementedError
# Error handlers.
# Error handlers.
def set_exception_handler(self, handler):
raise NotImplementedError
def set_exception_handler(self, handler):
raise NotImplementedError
def default_exception_handler(self, context):
raise NotImplementedError
def default_exception_handler(self, context):
raise NotImplementedError
def call_exception_handler(self, context):
raise NotImplementedError
def call_exception_handler(self, context):
raise NotImplementedError
# Debug flag management.
# Debug flag management.
def get_debug(self):
raise NotImplementedError
def get_debug(self):
raise NotImplementedError
def set_debug(self, enabled):
raise NotImplementedError
def set_debug(self, enabled):
raise NotImplementedError
class AbstractEventLoopPolicy:
"""Abstract policy for accessing the event loop."""
class AbstractEventLoopPolicy(object):
"""Abstract policy for accessing the event loop."""
def get_event_loop(self):
"""Get the event loop for the current context.
def get_event_loop(self):
"""Get the event loop for the current context.
Returns an event loop object implementing the BaseEventLoop interface,
or raises an exception in case no event loop has been set for the
current context and the current policy does not specify to create one.
Returns an event loop object implementing the BaseEventLoop interface,
or raises an exception in case no event loop has been set for the
current context and the current policy does not specify to create one.
It should never return None."""
raise NotImplementedError
It should never return None."""
raise NotImplementedError
def set_event_loop(self, loop):
"""Set the event loop for the current context to loop."""
raise NotImplementedError
def set_event_loop(self, loop):
"""Set the event loop for the current context to loop."""
raise NotImplementedError
def new_event_loop(self):
"""Create and return a new event loop object according to this
policy's rules. If there's need to set this loop as the event loop for
the current context, set_event_loop must be called explicitly."""
raise NotImplementedError
def new_event_loop(self):
"""Create and return a new event loop object according to this
policy's rules. If there's need to set this loop as the event loop for
the current context, set_event_loop must be called explicitly."""
raise NotImplementedError
# Child processes handling (Unix only).
# Child processes handling (Unix only).
def get_child_watcher(self):
"Get the watcher for child processes."
raise NotImplementedError
def get_child_watcher(self):
"Get the watcher for child processes."
raise NotImplementedError
def set_child_watcher(self, watcher):
"""Set the watcher for child processes."""
raise NotImplementedError
def set_child_watcher(self, watcher):
"""Set the watcher for child processes."""
raise NotImplementedError
class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy):

View File

@@ -5,13 +5,17 @@ __all__ = ['CancelledError', 'TimeoutError',
'Future', 'wrap_future',
]
import concurrent.futures._base
import logging
import reprlib
import sys
import traceback
try:
import reprlib # Python 3
except ImportError:
import repr as reprlib # Python 2
from . import compat
from . import events
from . import executor
# States for Future.
_PENDING = 'PENDING'
@@ -21,9 +25,9 @@ _FINISHED = 'FINISHED'
_PY34 = sys.version_info >= (3, 4)
_PY35 = sys.version_info >= (3, 5)
Error = concurrent.futures._base.Error
CancelledError = concurrent.futures.CancelledError
TimeoutError = concurrent.futures.TimeoutError
Error = executor.Error
CancelledError = executor.CancelledError
TimeoutError = executor.TimeoutError
STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging
@@ -32,7 +36,7 @@ class InvalidStateError(Error):
"""The operation is not allowed in this state."""
class _TracebackLogger:
class _TracebackLogger(object):
"""Helper to log a traceback upon destruction if not cleared.
This solves a nasty problem with Futures and Tasks that have an
@@ -112,7 +116,7 @@ class _TracebackLogger:
self.loop.call_exception_handler({'message': msg})
class Future:
class Future(object):
"""This class is *almost* compatible with concurrent.futures.Future.
Differences:
@@ -138,10 +142,14 @@ class Future:
_blocking = False # proper use of future (yield vs yield from)
# Used by Python 2 to raise the exception with the original traceback
# in the exception() method in debug mode
_exception_tb = None
_log_traceback = False # Used for Python 3.4 and later
_tb_logger = None # Used for Python 3.3 only
def __init__(self, *, loop=None):
def __init__(self, loop=None):
"""Initialize the future.
The optional event_loop argument allows to explicitly set the event
@@ -168,23 +176,23 @@ class Future:
if size == 1:
cb = format_cb(cb[0])
elif size == 2:
cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1]))
cb = '{0}, {1}'.format(format_cb(cb[0]), format_cb(cb[1]))
elif size > 2:
cb = '{}, <{} more>, {}'.format(format_cb(cb[0]),
size-2,
format_cb(cb[-1]))
cb = '{0}, <{1} more>, {2}'.format(format_cb(cb[0]),
size-2,
format_cb(cb[-1]))
return 'cb=[%s]' % cb
def _repr_info(self):
info = [self._state.lower()]
if self._state == _FINISHED:
if self._exception is not None:
info.append('exception={!r}'.format(self._exception))
info.append('exception={0!r}'.format(self._exception))
else:
# use reprlib to limit the length of the output, especially
# for very long strings
result = reprlib.repr(self._result)
info.append('result={}'.format(result))
info.append('result={0}'.format(result))
if self._callbacks:
info.append(self._format_callbacks())
if self._source_traceback:
@@ -272,8 +280,13 @@ class Future:
if self._tb_logger is not None:
self._tb_logger.clear()
self._tb_logger = None
exc_tb = self._exception_tb
self._exception_tb = None
if self._exception is not None:
raise self._exception
if exc_tb is not None:
compat.reraise(type(self._exception), self._exception, exc_tb)
else:
raise self._exception
return self._result
def exception(self):
@@ -292,6 +305,7 @@ class Future:
if self._tb_logger is not None:
self._tb_logger.clear()
self._tb_logger = None
self._exception_tb = None
return self._exception
def add_done_callback(self, fn):
@@ -334,31 +348,61 @@ class Future:
InvalidStateError.
"""
if self._state != _PENDING:
raise InvalidStateError('{}: {!r}'.format(self._state, self))
raise InvalidStateError('{0}: {1!r}'.format(self._state, self))
self._result = result
self._state = _FINISHED
self._schedule_callbacks()
def _get_exception_tb(self):
return self._exception_tb
def set_exception(self, exception):
self._set_exception_with_tb(exception, None)
def _set_exception_with_tb(self, exception, exc_tb):
"""Mark the future done and set an exception.
If the future is already done when this method is called, raises
InvalidStateError.
"""
if self._state != _PENDING:
raise InvalidStateError('{}: {!r}'.format(self._state, self))
raise InvalidStateError('{0}: {1!r}'.format(self._state, self))
if isinstance(exception, type):
exception = exception()
self._exception = exception
if exc_tb is not None:
self._exception_tb = exc_tb
exc_tb = None
elif self._loop.get_debug() and not compat.PY3:
self._exception_tb = sys.exc_info()[2]
self._state = _FINISHED
self._schedule_callbacks()
if _PY34:
self._log_traceback = True
else:
self._tb_logger = _TracebackLogger(self, exception)
# Arrange for the logger to be activated after all callbacks
# have had a chance to call result() or exception().
self._loop.call_soon(self._tb_logger.activate)
if hasattr(exception, '__traceback__'):
# Python 3: exception contains a link to the traceback
# Arrange for the logger to be activated after all callbacks
# have had a chance to call result() or exception().
self._loop.call_soon(self._tb_logger.activate)
else:
if self._loop.get_debug():
frame = sys._getframe(1)
tb = ['Traceback (most recent call last):\n']
if self._exception_tb is not None:
tb += traceback.format_tb(self._exception_tb)
else:
tb += traceback.format_stack(frame)
tb += traceback.format_exception_only(type(exception), exception)
self._tb_logger.tb = tb
else:
self._tb_logger.tb = traceback.format_exception_only(
type(exception),
exception)
self._tb_logger.exc = None
# Truly internal methods.
@@ -392,12 +436,18 @@ class Future:
__await__ = __iter__ # make compatible with 'await' expression
def wrap_future(fut, *, loop=None):
if events.asyncio is not None:
# Accept also asyncio Future objects for interoperability
_FUTURE_CLASSES = (Future, events.asyncio.Future)
else:
_FUTURE_CLASSES = Future
def wrap_future(fut, loop=None):
"""Wrap concurrent.futures.Future object."""
if isinstance(fut, Future):
if isinstance(fut, _FUTURE_CLASSES):
return fut
assert isinstance(fut, concurrent.futures.Future), \
'concurrent.futures.Future is expected, got {!r}'.format(fut)
assert isinstance(fut, executor.Future), \
'concurrent.futures.Future is expected, got {0!r}'.format(fut)
if loop is None:
loop = events.get_event_loop()
new_future = Future(loop=loop)

View File

@@ -7,7 +7,7 @@ import sys
from . import events
from . import futures
from .coroutines import coroutine
from .coroutines import coroutine, From, Return
_PY35 = sys.version_info >= (3, 5)
@@ -19,7 +19,7 @@ class _ContextManager:
This enables the following idiom for acquiring and releasing a
lock around a block:
with (yield from lock):
with (yield From(lock)):
<block>
while failing loudly when accidentally using:
@@ -43,7 +43,7 @@ class _ContextManager:
self._lock = None # Crudely prevent reuse.
class _ContextManagerMixin:
class _ContextManagerMixin(object):
def __enter__(self):
raise RuntimeError(
'"yield from" should be used as context manager expression')
@@ -111,16 +111,16 @@ class Lock(_ContextManagerMixin):
release() call resets the state to unlocked; first coroutine which
is blocked in acquire() is being processed.
acquire() is a coroutine and should be called with 'yield from'.
acquire() is a coroutine and should be called with 'yield From'.
Locks also support the context management protocol. '(yield from lock)'
Locks also support the context management protocol. '(yield From(lock))'
should be used as context manager expression.
Usage:
lock = Lock()
...
yield from lock
yield From(lock)
try:
...
finally:
@@ -130,20 +130,20 @@ class Lock(_ContextManagerMixin):
lock = Lock()
...
with (yield from lock):
with (yield From(lock)):
...
Lock objects can be tested for locking state:
if not lock.locked():
yield from lock
yield From(lock)
else:
# lock is acquired
...
"""
def __init__(self, *, loop=None):
def __init__(self, loop=None):
self._waiters = collections.deque()
self._locked = False
if loop is not None:
@@ -152,11 +152,11 @@ class Lock(_ContextManagerMixin):
self._loop = events.get_event_loop()
def __repr__(self):
res = super().__repr__()
res = super(Lock, self).__repr__()
extra = 'locked' if self._locked else 'unlocked'
if self._waiters:
extra = '{},waiters:{}'.format(extra, len(self._waiters))
return '<{} [{}]>'.format(res[1:-1], extra)
extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
return '<{0} [{1}]>'.format(res[1:-1], extra)
def locked(self):
"""Return True if lock is acquired."""
@@ -171,14 +171,14 @@ class Lock(_ContextManagerMixin):
"""
if not self._waiters and not self._locked:
self._locked = True
return True
raise Return(True)
fut = futures.Future(loop=self._loop)
self._waiters.append(fut)
try:
yield from fut
yield From(fut)
self._locked = True
return True
raise Return(True)
finally:
self._waiters.remove(fut)
@@ -204,7 +204,7 @@ class Lock(_ContextManagerMixin):
raise RuntimeError('Lock is not acquired.')
class Event:
class Event(object):
"""Asynchronous equivalent to threading.Event.
Class implementing event objects. An event manages a flag that can be set
@@ -213,7 +213,7 @@ class Event:
false.
"""
def __init__(self, *, loop=None):
def __init__(self, loop=None):
self._waiters = collections.deque()
self._value = False
if loop is not None:
@@ -222,11 +222,11 @@ class Event:
self._loop = events.get_event_loop()
def __repr__(self):
res = super().__repr__()
res = super(Event, self).__repr__()
extra = 'set' if self._value else 'unset'
if self._waiters:
extra = '{},waiters:{}'.format(extra, len(self._waiters))
return '<{} [{}]>'.format(res[1:-1], extra)
extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
return '<{0} [{1}]>'.format(res[1:-1], extra)
def is_set(self):
"""Return True if and only if the internal flag is true."""
@@ -259,13 +259,13 @@ class Event:
set() to set the flag to true, then return True.
"""
if self._value:
return True
raise Return(True)
fut = futures.Future(loop=self._loop)
self._waiters.append(fut)
try:
yield from fut
return True
yield From(fut)
raise Return(True)
finally:
self._waiters.remove(fut)
@@ -280,7 +280,7 @@ class Condition(_ContextManagerMixin):
A new Lock object is created and used as the underlying lock.
"""
def __init__(self, lock=None, *, loop=None):
def __init__(self, lock=None, loop=None):
if loop is not None:
self._loop = loop
else:
@@ -300,11 +300,11 @@ class Condition(_ContextManagerMixin):
self._waiters = collections.deque()
def __repr__(self):
res = super().__repr__()
res = super(Condition, self).__repr__()
extra = 'locked' if self.locked() else 'unlocked'
if self._waiters:
extra = '{},waiters:{}'.format(extra, len(self._waiters))
return '<{} [{}]>'.format(res[1:-1], extra)
extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
return '<{0} [{1}]>'.format(res[1:-1], extra)
@coroutine
def wait(self):
@@ -326,13 +326,13 @@ class Condition(_ContextManagerMixin):
fut = futures.Future(loop=self._loop)
self._waiters.append(fut)
try:
yield from fut
return True
yield From(fut)
raise Return(True)
finally:
self._waiters.remove(fut)
finally:
yield from self.acquire()
yield From(self.acquire())
@coroutine
def wait_for(self, predicate):
@@ -344,9 +344,9 @@ class Condition(_ContextManagerMixin):
"""
result = predicate()
while not result:
yield from self.wait()
yield From(self.wait())
result = predicate()
return result
raise Return(result)
def notify(self, n=1):
"""By default, wake up one coroutine waiting on this condition, if any.
@@ -396,7 +396,7 @@ class Semaphore(_ContextManagerMixin):
ValueError is raised.
"""
def __init__(self, value=1, *, loop=None):
def __init__(self, value=1, loop=None):
if value < 0:
raise ValueError("Semaphore initial value must be >= 0")
self._value = value
@@ -407,12 +407,12 @@ class Semaphore(_ContextManagerMixin):
self._loop = events.get_event_loop()
def __repr__(self):
res = super().__repr__()
extra = 'locked' if self.locked() else 'unlocked,value:{}'.format(
res = super(Semaphore, self).__repr__()
extra = 'locked' if self.locked() else 'unlocked,value:{0}'.format(
self._value)
if self._waiters:
extra = '{},waiters:{}'.format(extra, len(self._waiters))
return '<{} [{}]>'.format(res[1:-1], extra)
extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
return '<{0} [{1}]>'.format(res[1:-1], extra)
def locked(self):
"""Returns True if semaphore can not be acquired immediately."""
@@ -430,14 +430,14 @@ class Semaphore(_ContextManagerMixin):
"""
if not self._waiters and self._value > 0:
self._value -= 1
return True
raise Return(True)
fut = futures.Future(loop=self._loop)
self._waiters.append(fut)
try:
yield from fut
yield From(fut)
self._value -= 1
return True
raise Return(True)
finally:
self._waiters.remove(fut)
@@ -460,11 +460,11 @@ class BoundedSemaphore(Semaphore):
above the initial value.
"""
def __init__(self, value=1, *, loop=None):
def __init__(self, value=1, loop=None):
self._bound_value = value
super().__init__(value, loop=loop)
super(BoundedSemaphore, self).__init__(value, loop=loop)
def release(self):
if self._value >= self._bound_value:
raise ValueError('BoundedSemaphore released too many times')
super().release()
super(BoundedSemaphore, self).release()

View File

@@ -16,6 +16,9 @@ from . import futures
from . import sslproto
from . import transports
from .log import logger
from .compat import flatten_bytes
from .py33_exceptions import (BrokenPipeError,
ConnectionAbortedError, ConnectionResetError)
class _ProactorBasePipeTransport(transports._FlowControlMixin,
@@ -24,7 +27,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
def __init__(self, loop, sock, protocol, waiter=None,
extra=None, server=None):
super().__init__(extra, loop)
super(_ProactorBasePipeTransport, self).__init__(extra, loop)
self._set_extra(sock)
self._sock = sock
self._protocol = protocol
@@ -143,7 +146,8 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
def __init__(self, loop, sock, protocol, waiter=None,
extra=None, server=None):
super().__init__(loop, sock, protocol, waiter, extra, server)
super(_ProactorReadPipeTransport, self).__init__(loop, sock, protocol,
waiter, extra, server)
self._paused = False
self._loop.call_soon(self._loop_reading)
@@ -220,9 +224,7 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
"""Transport for write pipes."""
def write(self, data):
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError('data argument must be byte-ish (%r)',
type(data))
data = flatten_bytes(data)
if self._eof_written:
raise RuntimeError('write_eof() already called')
@@ -301,7 +303,7 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport):
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
super(_ProactorWritePipeTransport, self).__init__(*args, **kw)
self._read_fut = self._loop._proactor.recv(self._sock, 16)
self._read_fut.add_done_callback(self._pipe_closed)
@@ -368,7 +370,7 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport,
class BaseProactorEventLoop(base_events.BaseEventLoop):
def __init__(self, proactor):
super().__init__()
super(BaseProactorEventLoop, self).__init__()
logger.debug('Using proactor: %s', proactor.__class__.__name__)
self._proactor = proactor
self._selector = proactor # convenient alias
@@ -383,7 +385,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
extra, server)
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
server_side=False, server_hostname=None,
extra=None, server=None):
if not sslproto._is_sslproto_available():
raise NotImplementedError("Proactor event loop requires Python 3.5"
@@ -427,7 +429,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self._selector = None
# Close the event loop
super().close()
super(BaseProactorEventLoop, self).close()
def sock_recv(self, sock, n):
return self._proactor.recv(sock, n)

View File

@@ -4,7 +4,7 @@ __all__ = ['BaseProtocol', 'Protocol', 'DatagramProtocol',
'SubprocessProtocol']
class BaseProtocol:
class BaseProtocol(object):
"""Common base class for protocol interfaces.
Usually user implements protocols that derived from BaseProtocol

View File

@@ -9,6 +9,7 @@ import heapq
from . import events
from . import futures
from . import locks
from .coroutines import From, Return
from .tasks import coroutine
@@ -26,7 +27,7 @@ class QueueFull(Exception):
pass
class Queue:
class Queue(object):
"""A queue, useful for coordinating producer and consumer coroutines.
If maxsize is less than or equal to zero, the queue size is infinite. If it
@@ -38,7 +39,7 @@ class Queue:
interrupted between calling qsize() and doing an operation on the Queue.
"""
def __init__(self, maxsize=0, *, loop=None):
def __init__(self, maxsize=0, loop=None):
if loop is None:
self._loop = events.get_event_loop()
else:
@@ -73,22 +74,22 @@ class Queue:
self._finished.clear()
def __repr__(self):
return '<{} at {:#x} {}>'.format(
return '<{0} at {1:#x} {2}>'.format(
type(self).__name__, id(self), self._format())
def __str__(self):
return '<{} {}>'.format(type(self).__name__, self._format())
return '<{0} {1}>'.format(type(self).__name__, self._format())
def _format(self):
result = 'maxsize={!r}'.format(self._maxsize)
result = 'maxsize={0!r}'.format(self._maxsize)
if getattr(self, '_queue', None):
result += ' _queue={!r}'.format(list(self._queue))
result += ' _queue={0!r}'.format(list(self._queue))
if self._getters:
result += ' _getters[{}]'.format(len(self._getters))
result += ' _getters[{0}]'.format(len(self._getters))
if self._putters:
result += ' _putters[{}]'.format(len(self._putters))
result += ' _putters[{0}]'.format(len(self._putters))
if self._unfinished_tasks:
result += ' tasks={}'.format(self._unfinished_tasks)
result += ' tasks={0}'.format(self._unfinished_tasks)
return result
def _consume_done_getters(self):
@@ -149,7 +150,7 @@ class Queue:
waiter = futures.Future(loop=self._loop)
self._putters.append((item, waiter))
yield from waiter
yield From(waiter)
else:
self.__put_internal(item)
@@ -195,15 +196,16 @@ class Queue:
# ChannelTest.test_wait.
self._loop.call_soon(putter._set_result_unless_cancelled, None)
return self._get()
raise Return(self._get())
elif self.qsize():
return self._get()
raise Return(self._get())
else:
waiter = futures.Future(loop=self._loop)
self._getters.append(waiter)
return (yield from waiter)
result = yield From(waiter)
raise Return(result)
def get_nowait(self):
"""Remove and return an item from the queue.
@@ -257,7 +259,7 @@ class Queue:
When the count of unfinished tasks drops to zero, join() unblocks.
"""
if self._unfinished_tasks > 0:
yield from self._finished.wait()
yield From(self._finished.wait())
class PriorityQueue(Queue):

View File

@@ -14,6 +14,8 @@ import sys
import warnings
try:
import ssl
from .py3_ssl import (wrap_ssl_error, SSLContext, SSLWantReadError,
SSLWantWriteError)
except ImportError: # pragma: no cover
ssl = None
@@ -24,8 +26,26 @@ from . import futures
from . import selectors
from . import transports
from . import sslproto
from .compat import flatten_bytes
from .coroutines import coroutine
from .log import logger
from .py33_exceptions import (wrap_error,
BlockingIOError, InterruptedError, ConnectionAbortedError, BrokenPipeError,
ConnectionResetError)
# On Mac OS 10.6 with Python 2.6.1 or OpenIndiana 148 with Python 2.6.4,
# _SelectorSslTransport._read_ready() hangs if the socket has no data.
# Example: test_events.test_create_server_ssl()
_SSL_REQUIRES_SELECT = (sys.version_info < (2, 6, 6))
if _SSL_REQUIRES_SELECT:
import select
def _get_socket_error(sock, address):
err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
# Jump to the except clause below.
raise OSError(err, 'Connect call failed %s' % (address,))
def _test_selector_event(selector, fd, event):
@@ -46,7 +66,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
"""
def __init__(self, selector=None):
super().__init__()
super(BaseSelectorEventLoop, self).__init__()
if selector is None:
selector = selectors.DefaultSelector()
@@ -54,13 +74,13 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
self._selector = selector
self._make_self_pipe()
def _make_socket_transport(self, sock, protocol, waiter=None, *,
def _make_socket_transport(self, sock, protocol, waiter=None,
extra=None, server=None):
return _SelectorSocketTransport(self, sock, protocol, waiter,
extra, server)
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
server_side=False, server_hostname=None,
extra=None, server=None):
if not sslproto._is_sslproto_available():
return self._make_legacy_ssl_transport(
@@ -75,7 +95,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
return ssl_protocol._app_transport
def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext,
waiter, *,
waiter,
server_side=False, server_hostname=None,
extra=None, server=None):
# Use the legacy API: SSL_write, SSL_read, etc. The legacy API is used
@@ -95,7 +115,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
if self.is_closed():
return
self._close_self_pipe()
super().close()
super(BaseSelectorEventLoop, self).close()
if self._selector is not None:
self._selector.close()
self._selector = None
@@ -125,7 +145,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def _read_from_self(self):
while True:
try:
data = self._ssock.recv(4096)
data = wrap_error(self._ssock.recv, 4096)
if not data:
break
self._process_self_data(data)
@@ -143,7 +163,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
csock = self._csock
if csock is not None:
try:
csock.send(b'\0')
wrap_error(csock.send, b'\0')
except OSError:
if self._debug:
logger.debug("Fail to write a null byte into the "
@@ -158,14 +178,14 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def _accept_connection(self, protocol_factory, sock,
sslcontext=None, server=None):
try:
conn, addr = sock.accept()
conn, addr = wrap_error(sock.accept)
if self._debug:
logger.debug("%r got a new connection from %r: %r",
server, addr, conn)
conn.setblocking(False)
except (BlockingIOError, InterruptedError, ConnectionAbortedError):
pass # False alarm.
except OSError as exc:
except socket.error as exc:
# There's nowhere to send the error, so just log it.
if exc.errno in (errno.EMFILE, errno.ENFILE,
errno.ENOBUFS, errno.ENOMEM):
@@ -331,7 +351,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
if fut.cancelled():
return
try:
data = sock.recv(n)
data = wrap_error(sock.recv, n)
except (BlockingIOError, InterruptedError):
self.add_reader(fd, self._sock_recv, fut, True, sock, n)
except Exception as exc:
@@ -368,7 +388,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
return
try:
n = sock.send(data)
n = wrap_error(sock.send, data)
except (BlockingIOError, InterruptedError):
n = 0
except Exception as exc:
@@ -408,7 +428,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def _sock_connect(self, fut, sock, address):
fd = sock.fileno()
try:
sock.connect(address)
wrap_error(sock.connect, address)
except (BlockingIOError, InterruptedError):
# Issue #23618: When the C function connect() fails with EINTR, the
# connection runs in background. We have to wait until the socket
@@ -430,10 +450,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
return
try:
err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
# Jump to any except clause below.
raise OSError(err, 'Connect call failed %s' % (address,))
wrap_error(_get_socket_error, sock, address)
except (BlockingIOError, InterruptedError):
# socket is still registered, the callback will be retried later
pass
@@ -465,7 +482,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
if fut.cancelled():
return
try:
conn, address = sock.accept()
conn, address = wrap_error(sock.accept)
conn.setblocking(False)
except (BlockingIOError, InterruptedError):
self.add_reader(fd, self._sock_accept, fut, True, sock)
@@ -506,7 +523,7 @@ class _SelectorTransport(transports._FlowControlMixin,
_sock = None
def __init__(self, loop, sock, protocol, extra=None, server=None):
super().__init__(extra, loop)
super(_SelectorTransport, self).__init__(extra, loop)
self._extra['socket'] = sock
self._extra['sockname'] = sock.getsockname()
if 'peername' not in self._extra:
@@ -593,7 +610,7 @@ class _SelectorTransport(transports._FlowControlMixin,
if self._conn_lost:
return
if self._buffer:
self._buffer.clear()
del self._buffer[:]
self._loop.remove_writer(self._sock_fd)
if not self._closing:
self._closing = True
@@ -623,7 +640,7 @@ class _SelectorSocketTransport(_SelectorTransport):
def __init__(self, loop, sock, protocol, waiter=None,
extra=None, server=None):
super().__init__(loop, sock, protocol, extra, server)
super(_SelectorSocketTransport, self).__init__(loop, sock, protocol, extra, server)
self._eof = False
self._paused = False
@@ -657,7 +674,7 @@ class _SelectorSocketTransport(_SelectorTransport):
def _read_ready(self):
try:
data = self._sock.recv(self.max_size)
data = wrap_error(self._sock.recv, self.max_size)
except (BlockingIOError, InterruptedError):
pass
except Exception as exc:
@@ -678,9 +695,7 @@ class _SelectorSocketTransport(_SelectorTransport):
self.close()
def write(self, data):
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError('data argument must be byte-ish (%r)',
type(data))
data = flatten_bytes(data)
if self._eof:
raise RuntimeError('Cannot call write() after write_eof()')
if not data:
@@ -695,7 +710,7 @@ class _SelectorSocketTransport(_SelectorTransport):
if not self._buffer:
# Optimization: try to send now.
try:
n = self._sock.send(data)
n = wrap_error(self._sock.send, data)
except (BlockingIOError, InterruptedError):
pass
except Exception as exc:
@@ -715,13 +730,14 @@ class _SelectorSocketTransport(_SelectorTransport):
def _write_ready(self):
assert self._buffer, 'Data should not be empty'
data = flatten_bytes(self._buffer)
try:
n = self._sock.send(self._buffer)
n = wrap_error(self._sock.send, data)
except (BlockingIOError, InterruptedError):
pass
except Exception as exc:
self._loop.remove_writer(self._sock_fd)
self._buffer.clear()
del self._buffer[:]
self._fatal_error(exc, 'Fatal write error on socket transport')
else:
if n:
@@ -766,7 +782,7 @@ class _SelectorSslTransport(_SelectorTransport):
wrap_kwargs['server_hostname'] = server_hostname
sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs)
super().__init__(loop, sslsock, protocol, extra, server)
super(_SelectorSslTransport, self).__init__(loop, sslsock, protocol, extra, server)
# the protocol connection is only made after the SSL handshake
self._protocol_connected = False
@@ -797,12 +813,12 @@ class _SelectorSslTransport(_SelectorTransport):
def _on_handshake(self, start_time):
try:
self._sock.do_handshake()
except ssl.SSLWantReadError:
wrap_ssl_error(self._sock.do_handshake)
except SSLWantReadError:
self._loop.add_reader(self._sock_fd,
self._on_handshake, start_time)
return
except ssl.SSLWantWriteError:
except SSLWantWriteError:
self._loop.add_writer(self._sock_fd,
self._on_handshake, start_time)
return
@@ -842,8 +858,9 @@ class _SelectorSslTransport(_SelectorTransport):
# Add extra info that becomes available after handshake.
self._extra.update(peercert=peercert,
cipher=self._sock.cipher(),
compression=self._sock.compression(),
)
if hasattr(self._sock, 'compression'):
self._extra['compression'] = self._sock.compression()
self._read_wants_write = False
self._write_wants_read = False
@@ -883,6 +900,9 @@ class _SelectorSslTransport(_SelectorTransport):
if self._loop.get_debug():
logger.debug("%r resumes reading", self)
def _sock_recv(self):
return wrap_ssl_error(self._sock.recv, self.max_size)
def _read_ready(self):
if self._write_wants_read:
self._write_wants_read = False
@@ -892,10 +912,16 @@ class _SelectorSslTransport(_SelectorTransport):
self._loop.add_writer(self._sock_fd, self._write_ready)
try:
data = self._sock.recv(self.max_size)
except (BlockingIOError, InterruptedError, ssl.SSLWantReadError):
if _SSL_REQUIRES_SELECT:
rfds = (self._sock.fileno(),)
rfds = select.select(rfds, (), (), 0.0)[0]
if not rfds:
# False alarm.
return
data = wrap_error(self._sock_recv)
except (BlockingIOError, InterruptedError, SSLWantReadError):
pass
except ssl.SSLWantWriteError:
except SSLWantWriteError:
self._read_wants_write = True
self._loop.remove_reader(self._sock_fd)
self._loop.add_writer(self._sock_fd, self._write_ready)
@@ -924,17 +950,18 @@ class _SelectorSslTransport(_SelectorTransport):
self._loop.add_reader(self._sock_fd, self._read_ready)
if self._buffer:
data = flatten_bytes(self._buffer)
try:
n = self._sock.send(self._buffer)
except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError):
n = wrap_error(self._sock.send, data)
except (BlockingIOError, InterruptedError, SSLWantWriteError):
n = 0
except ssl.SSLWantReadError:
except SSLWantReadError:
n = 0
self._loop.remove_writer(self._sock_fd)
self._write_wants_read = True
except Exception as exc:
self._loop.remove_writer(self._sock_fd)
self._buffer.clear()
del self._buffer[:]
self._fatal_error(exc, 'Fatal write error on SSL transport')
return
@@ -949,9 +976,7 @@ class _SelectorSslTransport(_SelectorTransport):
self._call_connection_lost(None)
def write(self, data):
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError('data argument must be byte-ish (%r)',
type(data))
data = flatten_bytes(data)
if not data:
return
@@ -978,7 +1003,8 @@ class _SelectorDatagramTransport(_SelectorTransport):
def __init__(self, loop, sock, protocol, address=None,
waiter=None, extra=None):
super().__init__(loop, sock, protocol, extra)
super(_SelectorDatagramTransport, self).__init__(loop, sock,
protocol, extra)
self._address = address
self._loop.call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
@@ -993,7 +1019,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
def _read_ready(self):
try:
data, addr = self._sock.recvfrom(self.max_size)
data, addr = wrap_error(self._sock.recvfrom, self.max_size)
except (BlockingIOError, InterruptedError):
pass
except OSError as exc:
@@ -1004,9 +1030,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
self._protocol.datagram_received(data, addr)
def sendto(self, data, addr=None):
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError('data argument must be byte-ish (%r)',
type(data))
data = flatten_bytes(data)
if not data:
return
@@ -1024,9 +1048,9 @@ class _SelectorDatagramTransport(_SelectorTransport):
# Attempt to send it right away first.
try:
if self._address:
self._sock.send(data)
wrap_error(self._sock.send, data)
else:
self._sock.sendto(data, addr)
wrap_error(self._sock.sendto, data, addr)
return
except (BlockingIOError, InterruptedError):
self._loop.add_writer(self._sock_fd, self._sendto_ready)
@@ -1047,9 +1071,9 @@ class _SelectorDatagramTransport(_SelectorTransport):
data, addr = self._buffer.popleft()
try:
if self._address:
self._sock.send(data)
wrap_error(self._sock.send, data)
else:
self._sock.sendto(data, addr)
wrap_error(self._sock.sendto, data, addr)
except (BlockingIOError, InterruptedError):
self._buffer.appendleft((data, addr)) # Try again later.
break

View File

@@ -11,6 +11,9 @@ import math
import select
import sys
from .py33_exceptions import wrap_error, InterruptedError
from .compat import integer_types
# generic events, that must be mapped to implementation-specific ones
EVENT_READ = (1 << 0)
@@ -29,16 +32,16 @@ def _fileobj_to_fd(fileobj):
Raises:
ValueError if the object is invalid
"""
if isinstance(fileobj, int):
if isinstance(fileobj, integer_types):
fd = fileobj
else:
try:
fd = int(fileobj.fileno())
except (AttributeError, TypeError, ValueError):
raise ValueError("Invalid file object: "
"{!r}".format(fileobj)) from None
"{0!r}".format(fileobj))
if fd < 0:
raise ValueError("Invalid file descriptor: {}".format(fd))
raise ValueError("Invalid file descriptor: {0}".format(fd))
return fd
@@ -61,13 +64,13 @@ class _SelectorMapping(Mapping):
fd = self._selector._fileobj_lookup(fileobj)
return self._selector._fd_to_key[fd]
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj)) from None
raise KeyError("{0!r} is not registered".format(fileobj))
def __iter__(self):
return iter(self._selector._fd_to_key)
class BaseSelector(metaclass=ABCMeta):
class BaseSelector(object):
"""Selector abstract base class.
A selector supports registering file objects to be monitored for specific
@@ -81,6 +84,7 @@ class BaseSelector(metaclass=ABCMeta):
depending on the platform. The default `Selector` class uses the most
efficient implementation on the current platform.
"""
__metaclass__ = ABCMeta
@abstractmethod
def register(self, fileobj, events, data=None):
@@ -179,7 +183,7 @@ class BaseSelector(metaclass=ABCMeta):
try:
return mapping[fileobj]
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj)) from None
raise KeyError("{0!r} is not registered".format(fileobj))
@abstractmethod
def get_map(self):
@@ -223,12 +227,12 @@ class _BaseSelectorImpl(BaseSelector):
def register(self, fileobj, events, data=None):
if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)):
raise ValueError("Invalid events: {!r}".format(events))
raise ValueError("Invalid events: {0!r}".format(events))
key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
if key.fd in self._fd_to_key:
raise KeyError("{!r} (FD {}) is already registered"
raise KeyError("{0!r} (FD {1}) is already registered"
.format(fileobj, key.fd))
self._fd_to_key[key.fd] = key
@@ -238,7 +242,7 @@ class _BaseSelectorImpl(BaseSelector):
try:
key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj)) from None
raise KeyError("{0!r} is not registered".format(fileobj))
return key
def modify(self, fileobj, events, data=None):
@@ -246,7 +250,7 @@ class _BaseSelectorImpl(BaseSelector):
try:
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj)) from None
raise KeyError("{0!r} is not registered".format(fileobj))
if events != key.events:
self.unregister(fileobj)
key = self.register(fileobj, events, data)
@@ -282,12 +286,12 @@ class SelectSelector(_BaseSelectorImpl):
"""Select-based selector."""
def __init__(self):
super().__init__()
super(SelectSelector, self).__init__()
self._readers = set()
self._writers = set()
def register(self, fileobj, events, data=None):
key = super().register(fileobj, events, data)
key = super(SelectSelector, self).register(fileobj, events, data)
if events & EVENT_READ:
self._readers.add(key.fd)
if events & EVENT_WRITE:
@@ -295,7 +299,7 @@ class SelectSelector(_BaseSelectorImpl):
return key
def unregister(self, fileobj):
key = super().unregister(fileobj)
key = super(SelectSelector, self).unregister(fileobj)
self._readers.discard(key.fd)
self._writers.discard(key.fd)
return key
@@ -311,7 +315,8 @@ class SelectSelector(_BaseSelectorImpl):
timeout = None if timeout is None else max(timeout, 0)
ready = []
try:
r, w, _ = self._select(self._readers, self._writers, [], timeout)
r, w, _ = wrap_error(self._select,
self._readers, self._writers, [], timeout)
except InterruptedError:
return ready
r = set(r)
@@ -335,11 +340,11 @@ if hasattr(select, 'poll'):
"""Poll-based selector."""
def __init__(self):
super().__init__()
super(PollSelector, self).__init__()
self._poll = select.poll()
def register(self, fileobj, events, data=None):
key = super().register(fileobj, events, data)
key = super(PollSelector, self).register(fileobj, events, data)
poll_events = 0
if events & EVENT_READ:
poll_events |= select.POLLIN
@@ -349,7 +354,7 @@ if hasattr(select, 'poll'):
return key
def unregister(self, fileobj):
key = super().unregister(fileobj)
key = super(PollSelector, self).unregister(fileobj)
self._poll.unregister(key.fd)
return key
@@ -361,10 +366,10 @@ if hasattr(select, 'poll'):
else:
# poll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
timeout = int(math.ceil(timeout * 1e3))
ready = []
try:
fd_event_list = self._poll.poll(timeout)
fd_event_list = wrap_error(self._poll.poll, timeout)
except InterruptedError:
return ready
for fd, event in fd_event_list:
@@ -386,14 +391,14 @@ if hasattr(select, 'epoll'):
"""Epoll-based selector."""
def __init__(self):
super().__init__()
super(EpollSelector, self).__init__()
self._epoll = select.epoll()
def fileno(self):
return self._epoll.fileno()
def register(self, fileobj, events, data=None):
key = super().register(fileobj, events, data)
key = super(EpollSelector, self).register(fileobj, events, data)
epoll_events = 0
if events & EVENT_READ:
epoll_events |= select.EPOLLIN
@@ -403,7 +408,7 @@ if hasattr(select, 'epoll'):
return key
def unregister(self, fileobj):
key = super().unregister(fileobj)
key = super(EpollSelector, self).unregister(fileobj)
try:
self._epoll.unregister(key.fd)
except OSError:
@@ -429,7 +434,7 @@ if hasattr(select, 'epoll'):
ready = []
try:
fd_event_list = self._epoll.poll(timeout, max_ev)
fd_event_list = wrap_error(self._epoll.poll, timeout, max_ev)
except InterruptedError:
return ready
for fd, event in fd_event_list:
@@ -446,7 +451,7 @@ if hasattr(select, 'epoll'):
def close(self):
self._epoll.close()
super().close()
super(EpollSelector, self).close()
if hasattr(select, 'devpoll'):
@@ -455,14 +460,14 @@ if hasattr(select, 'devpoll'):
"""Solaris /dev/poll selector."""
def __init__(self):
super().__init__()
super(DevpollSelector, self).__init__()
self._devpoll = select.devpoll()
def fileno(self):
return self._devpoll.fileno()
def register(self, fileobj, events, data=None):
key = super().register(fileobj, events, data)
key = super(DevpollSelector, self).register(fileobj, events, data)
poll_events = 0
if events & EVENT_READ:
poll_events |= select.POLLIN
@@ -472,7 +477,7 @@ if hasattr(select, 'devpoll'):
return key
def unregister(self, fileobj):
key = super().unregister(fileobj)
key = super(DevpollSelector, self).unregister(fileobj)
self._devpoll.unregister(key.fd)
return key
@@ -504,7 +509,7 @@ if hasattr(select, 'devpoll'):
def close(self):
self._devpoll.close()
super().close()
super(DevpollSelector, self).close()
if hasattr(select, 'kqueue'):
@@ -513,14 +518,14 @@ if hasattr(select, 'kqueue'):
"""Kqueue-based selector."""
def __init__(self):
super().__init__()
super(KqueueSelector, self).__init__()
self._kqueue = select.kqueue()
def fileno(self):
return self._kqueue.fileno()
def register(self, fileobj, events, data=None):
key = super().register(fileobj, events, data)
key = super(KqueueSelector, self).register(fileobj, events, data)
if events & EVENT_READ:
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
select.KQ_EV_ADD)
@@ -532,7 +537,7 @@ if hasattr(select, 'kqueue'):
return key
def unregister(self, fileobj):
key = super().unregister(fileobj)
key = super(KqueueSelector, self).unregister(fileobj)
if key.events & EVENT_READ:
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
select.KQ_EV_DELETE)
@@ -557,7 +562,8 @@ if hasattr(select, 'kqueue'):
max_ev = len(self._fd_to_key)
ready = []
try:
kev_list = self._kqueue.control(None, max_ev, timeout)
kev_list = wrap_error(self._kqueue.control,
None, max_ev, timeout)
except InterruptedError:
return ready
for kev in kev_list:
@@ -576,7 +582,7 @@ if hasattr(select, 'kqueue'):
def close(self):
self._kqueue.close()
super().close()
super(KqueueSelector, self).close()
# Choose the best implementation, roughly:

View File

@@ -9,6 +9,7 @@ except ImportError: # pragma: no cover
from . import protocols
from . import transports
from .log import logger
from .py3_ssl import BACKPORT_SSL_CONTEXT
def _create_transport_context(server_side, server_hostname):
@@ -26,10 +27,11 @@ def _create_transport_context(server_side, server_hostname):
else:
# Fallback for Python 3.3.
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.set_default_verify_paths()
sslcontext.verify_mode = ssl.CERT_REQUIRED
if not BACKPORT_SSL_CONTEXT:
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.set_default_verify_paths()
sslcontext.verify_mode = ssl.CERT_REQUIRED
return sslcontext
@@ -43,6 +45,11 @@ _DO_HANDSHAKE = "DO_HANDSHAKE"
_WRAPPED = "WRAPPED"
_SHUTDOWN = "SHUTDOWN"
if hasattr(ssl, 'CertificateError'):
_SSL_ERRORS = (ssl.SSLError, ssl.CertificateError)
else:
_SSL_ERRORS = ssl.SSLError
class _SSLPipe(object):
"""An SSL "Pipe".
@@ -224,7 +231,7 @@ class _SSLPipe(object):
elif self._state == _UNWRAPPED:
# Drain possible plaintext data after close_notify.
appdata.append(self._incoming.read())
except (ssl.SSLError, ssl.CertificateError) as exc:
except _SSL_ERRORS as exc:
if getattr(exc, 'errno', None) not in (
ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
ssl.SSL_ERROR_SYSCALL):
@@ -569,7 +576,8 @@ class SSLProtocol(protocols.Protocol):
ssl.match_hostname(peercert, self._server_hostname)
except BaseException as exc:
if self._loop.get_debug():
if isinstance(exc, ssl.CertificateError):
if (hasattr(ssl, 'CertificateError')
and isinstance(exc, ssl.CertificateError)):
logger.warning("%r: SSL handshake failed "
"on verifying the certificate",
self, exc_info=True)

View File

@@ -15,7 +15,8 @@ from . import coroutines
from . import events
from . import futures
from . import protocols
from .coroutines import coroutine
from .coroutines import coroutine, From, Return
from .py33_exceptions import ConnectionResetError
from .log import logger
@@ -38,7 +39,7 @@ class IncompleteReadError(EOFError):
@coroutine
def open_connection(host=None, port=None, *,
def open_connection(host=None, port=None,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""A wrapper for create_connection() returning a (reader, writer) pair.
@@ -61,14 +62,14 @@ def open_connection(host=None, port=None, *,
loop = events.get_event_loop()
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop)
transport, _ = yield from loop.create_connection(
lambda: protocol, host, port, **kwds)
transport, _ = yield From(loop.create_connection(
lambda: protocol, host, port, **kwds))
writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
raise Return(reader, writer)
@coroutine
def start_server(client_connected_cb, host=None, port=None, *,
def start_server(client_connected_cb, host=None, port=None,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Start a socket server, call back for each client connected.
@@ -100,28 +101,29 @@ def start_server(client_connected_cb, host=None, port=None, *,
loop=loop)
return protocol
return (yield from loop.create_server(factory, host, port, **kwds))
server = yield From(loop.create_server(factory, host, port, **kwds))
raise Return(server)
if hasattr(socket, 'AF_UNIX'):
# UNIX Domain Sockets are supported on this platform
@coroutine
def open_unix_connection(path=None, *,
def open_unix_connection(path=None,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
if loop is None:
loop = events.get_event_loop()
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop)
transport, _ = yield from loop.create_unix_connection(
lambda: protocol, path, **kwds)
transport, _ = yield From(loop.create_unix_connection(
lambda: protocol, path, **kwds))
writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
raise Return(reader, writer)
@coroutine
def start_unix_server(client_connected_cb, path=None, *,
def start_unix_server(client_connected_cb, path=None,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `start_server` but works with UNIX Domain Sockets."""
if loop is None:
@@ -133,7 +135,8 @@ if hasattr(socket, 'AF_UNIX'):
loop=loop)
return protocol
return (yield from loop.create_unix_server(factory, path, **kwds))
server = (yield From(loop.create_unix_server(factory, path, **kwds)))
raise Return(server)
class FlowControlMixin(protocols.Protocol):
@@ -199,7 +202,7 @@ class FlowControlMixin(protocols.Protocol):
assert waiter is None or waiter.cancelled()
waiter = futures.Future(loop=self._loop)
self._drain_waiter = waiter
yield from waiter
yield From(waiter)
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
@@ -212,7 +215,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
"""
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
super().__init__(loop=loop)
super(StreamReaderProtocol, self).__init__(loop=loop)
self._stream_reader = stream_reader
self._stream_writer = None
self._client_connected_cb = client_connected_cb
@@ -233,7 +236,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
self._stream_reader.feed_eof()
else:
self._stream_reader.set_exception(exc)
super().connection_lost(exc)
super(StreamReaderProtocol, self).connection_lost(exc)
def data_received(self, data):
self._stream_reader.feed_data(data)
@@ -242,7 +245,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
self._stream_reader.feed_eof()
class StreamWriter:
class StreamWriter(object):
"""Wraps a Transport.
This exposes write(), writelines(), [can_]write_eof(),
@@ -295,16 +298,16 @@ class StreamWriter:
The intended use is to write
w.write(data)
yield from w.drain()
yield From(w.drain())
"""
if self._reader is not None:
exc = self._reader.exception()
if exc is not None:
raise exc
yield from self._protocol._drain_helper()
yield From(self._protocol._drain_helper())
class StreamReader:
class StreamReader(object):
def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
# The line length limit is a security feature;
@@ -391,9 +394,16 @@ class StreamReader:
raise RuntimeError('%s() called while another coroutine is '
'already waiting for incoming data' % func_name)
# In asyncio, there is no need to recheck if we got data or EOF thanks
# to "yield from". In trollius, a StreamReader method can be called
# after the _wait_for_data() coroutine is scheduled and before it is
# really executed.
if self._buffer or self._eof:
return
self._waiter = futures.Future(loop=self._loop)
try:
yield from self._waiter
yield From(self._waiter)
finally:
self._waiter = None
@@ -410,7 +420,7 @@ class StreamReader:
ichar = self._buffer.find(b'\n')
if ichar < 0:
line.extend(self._buffer)
self._buffer.clear()
del self._buffer[:]
else:
ichar += 1
line.extend(self._buffer[:ichar])
@@ -425,10 +435,10 @@ class StreamReader:
break
if not_enough:
yield from self._wait_for_data('readline')
yield From(self._wait_for_data('readline'))
self._maybe_resume_transport()
return bytes(line)
raise Return(bytes(line))
@coroutine
def read(self, n=-1):
@@ -436,7 +446,7 @@ class StreamReader:
raise self._exception
if not n:
return b''
raise Return(b'')
if n < 0:
# This used to just loop creating a new waiter hoping to
@@ -445,25 +455,25 @@ class StreamReader:
# bytes. So just call self.read(self._limit) until EOF.
blocks = []
while True:
block = yield from self.read(self._limit)
block = yield From(self.read(self._limit))
if not block:
break
blocks.append(block)
return b''.join(blocks)
raise Return(b''.join(blocks))
else:
if not self._buffer and not self._eof:
yield from self._wait_for_data('read')
yield From(self._wait_for_data('read'))
if n < 0 or len(self._buffer) <= n:
data = bytes(self._buffer)
self._buffer.clear()
del self._buffer[:]
else:
# n > 0 and len(self._buffer) > n
data = bytes(self._buffer[:n])
del self._buffer[:n]
self._maybe_resume_transport()
return data
raise Return(data)
@coroutine
def readexactly(self, n):
@@ -479,14 +489,14 @@ class StreamReader:
blocks = []
while n > 0:
block = yield from self.read(n)
block = yield From(self.read(n))
if not block:
partial = b''.join(blocks)
raise IncompleteReadError(partial, len(partial) + n)
blocks.append(block)
n -= len(block)
return b''.join(blocks)
raise Return(b''.join(blocks))
if _PY35:
@coroutine

View File

@@ -8,13 +8,16 @@ from . import futures
from . import protocols
from . import streams
from . import tasks
from .coroutines import coroutine
from .coroutines import coroutine, From, Return
from .py33_exceptions import (BrokenPipeError, ConnectionResetError,
ProcessLookupError)
from .log import logger
PIPE = subprocess.PIPE
STDOUT = subprocess.STDOUT
DEVNULL = subprocess.DEVNULL
if hasattr(subprocess, 'DEVNULL'):
DEVNULL = subprocess.DEVNULL
class SubprocessStreamProtocol(streams.FlowControlMixin,
@@ -22,7 +25,7 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
"""Like StreamReaderProtocol, but for a subprocess."""
def __init__(self, limit, loop):
super().__init__(loop=loop)
super(SubprocessStreamProtocol, self).__init__(loop=loop)
self._limit = limit
self.stdin = self.stdout = self.stderr = None
self._transport = None
@@ -115,7 +118,8 @@ class Process:
"""Wait until the process exit and return the process return code.
This method is a coroutine."""
return (yield from self._transport._wait())
return_code = yield From(self._transport._wait())
raise Return(return_code)
def send_signal(self, signal):
self._transport.send_signal(signal)
@@ -134,7 +138,7 @@ class Process:
logger.debug('%r communicate: feed stdin (%s bytes)',
self, len(input))
try:
yield from self.stdin.drain()
yield From(self.stdin.drain())
except (BrokenPipeError, ConnectionResetError) as exc:
# communicate() ignores BrokenPipeError and ConnectionResetError
if debug:
@@ -159,12 +163,12 @@ class Process:
if self._loop.get_debug():
name = 'stdout' if fd == 1 else 'stderr'
logger.debug('%r communicate: read %s', self, name)
output = yield from stream.read()
output = yield From(stream.read())
if self._loop.get_debug():
name = 'stdout' if fd == 1 else 'stderr'
logger.debug('%r communicate: close %s', self, name)
transport.close()
return output
raise Return(output)
@coroutine
def communicate(self, input=None):
@@ -180,36 +184,43 @@ class Process:
stderr = self._read_stream(2)
else:
stderr = self._noop()
stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr,
loop=self._loop)
yield from self.wait()
return (stdout, stderr)
stdin, stdout, stderr = yield From(tasks.gather(stdin, stdout, stderr,
loop=self._loop))
yield From(self.wait())
raise Return(stdout, stderr)
@coroutine
def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None,
loop=None, limit=streams._DEFAULT_LIMIT, **kwds):
def create_subprocess_shell(cmd, **kwds):
stdin = kwds.pop('stdin', None)
stdout = kwds.pop('stdout', None)
stderr = kwds.pop('stderr', None)
loop = kwds.pop('loop', None)
limit = kwds.pop('limit', streams._DEFAULT_LIMIT)
if loop is None:
loop = events.get_event_loop()
protocol_factory = lambda: SubprocessStreamProtocol(limit=limit,
loop=loop)
transport, protocol = yield from loop.subprocess_shell(
transport, protocol = yield From(loop.subprocess_shell(
protocol_factory,
cmd, stdin=stdin, stdout=stdout,
stderr=stderr, **kwds)
return Process(transport, protocol, loop)
stderr=stderr, **kwds))
raise Return(Process(transport, protocol, loop))
@coroutine
def create_subprocess_exec(program, *args, stdin=None, stdout=None,
stderr=None, loop=None,
limit=streams._DEFAULT_LIMIT, **kwds):
def create_subprocess_exec(program, *args, **kwds):
stdin = kwds.pop('stdin', None)
stdout = kwds.pop('stdout', None)
stderr = kwds.pop('stderr', None)
loop = kwds.pop('loop', None)
limit = kwds.pop('limit', streams._DEFAULT_LIMIT)
if loop is None:
loop = events.get_event_loop()
protocol_factory = lambda: SubprocessStreamProtocol(limit=limit,
loop=loop)
transport, protocol = yield from loop.subprocess_exec(
transport, protocol = yield From(loop.subprocess_exec(
protocol_factory,
program, *args,
stdin=stdin, stdout=stdout,
stderr=stderr, **kwds)
return Process(transport, protocol, loop)
stderr=stderr, **kwds))
raise Return(Process(transport, protocol, loop))

View File

@@ -14,16 +14,30 @@ import sys
import types
import traceback
import warnings
import weakref
try:
from weakref import WeakSet
except ImportError:
# Python 2.6
from .py27_weakrefset import WeakSet
from . import compat
from . import coroutines
from . import events
from . import executor
from . import futures
from .coroutines import coroutine
from .locks import Lock, Condition, Semaphore, _ContextManager
from .coroutines import coroutine, From, Return
_PY34 = (sys.version_info >= (3, 4))
@coroutine
def _lock_coroutine(lock):
yield From(lock.acquire())
raise Return(_ContextManager(lock))
class Task(futures.Future):
"""A coroutine wrapped in a Future."""
@@ -37,7 +51,7 @@ class Task(futures.Future):
# must be _wakeup().
# Weak set containing all tasks alive.
_all_tasks = weakref.WeakSet()
_all_tasks = WeakSet()
# Dictionary containing tasks that are currently active in
# all running event loops. {EventLoop: Task}
@@ -67,11 +81,11 @@ class Task(futures.Future):
"""
if loop is None:
loop = events.get_event_loop()
return {t for t in cls._all_tasks if t._loop is loop}
return set(t for t in cls._all_tasks if t._loop is loop)
def __init__(self, coro, *, loop=None):
def __init__(self, coro, loop=None):
assert coroutines.iscoroutine(coro), repr(coro)
super().__init__(loop=loop)
super(Task, self).__init__(loop=loop)
if self._source_traceback:
del self._source_traceback[-1]
self._coro = coro
@@ -96,7 +110,7 @@ class Task(futures.Future):
futures.Future.__del__(self)
def _repr_info(self):
info = super()._repr_info()
info = super(Task, self)._repr_info()
if self._must_cancel:
# replace status
@@ -109,7 +123,7 @@ class Task(futures.Future):
info.insert(2, 'wait_for=%r' % self._fut_waiter)
return info
def get_stack(self, *, limit=None):
def get_stack(self, limit=None):
"""Return the list of stack frames for this task's coroutine.
If the coroutine is not done, this returns the stack where it is
@@ -152,7 +166,7 @@ class Task(futures.Future):
tb = tb.tb_next
return frames
def print_stack(self, *, limit=None, file=None):
def print_stack(self, limit=None, file=None):
"""Print the stack or traceback for this task's coroutine.
This produces output similar to that of the traceback module,
@@ -219,9 +233,9 @@ class Task(futures.Future):
self._must_cancel = True
return True
def _step(self, value=None, exc=None):
def _step(self, value=None, exc=None, exc_tb=None):
assert not self.done(), \
'_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc)
'_step(): already done: {0!r}, {1!r}, {2!r}'.format(self, value, exc)
if self._must_cancel:
if not isinstance(exc, futures.CancelledError):
exc = futures.CancelledError()
@@ -229,6 +243,10 @@ class Task(futures.Future):
coro = self._coro
self._fut_waiter = None
if exc_tb is not None:
init_exc = exc
else:
init_exc = None
self.__class__._current_tasks[self._loop] = self
# Call either coro.throw(exc) or coro.send(value).
try:
@@ -237,71 +255,104 @@ class Task(futures.Future):
else:
result = coro.send(value)
except StopIteration as exc:
self.set_result(exc.value)
except futures.CancelledError as exc:
super().cancel() # I.e., Future.cancel(self).
except Exception as exc:
self.set_exception(exc)
except BaseException as exc:
self.set_exception(exc)
raise
else:
if isinstance(result, futures.Future):
# Yielded Future must come from Future.__iter__().
if result._blocking:
result._blocking = False
result.add_done_callback(self._wakeup)
self._fut_waiter = result
if self._must_cancel:
if self._fut_waiter.cancel():
self._must_cancel = False
if compat.PY33:
# asyncio Task object? get the result of the coroutine
result = exc.value
else:
if isinstance(exc, Return):
exc.raised = True
result = exc.value
else:
self._loop.call_soon(
self._step, None,
RuntimeError(
'yield was used instead of yield from '
'in task {!r} with {!r}'.format(self, result)))
result = None
self.set_result(result)
except futures.CancelledError as exc:
super(Task, self).cancel() # I.e., Future.cancel(self).
except BaseException as exc:
if exc is init_exc:
self._set_exception_with_tb(exc, exc_tb)
exc_tb = None
else:
self.set_exception(exc)
if not isinstance(exc, Exception):
# reraise BaseException
raise
else:
if coroutines._DEBUG:
if not coroutines._coroutine_at_yield_from(self._coro):
# trollius coroutine must "yield From(...)"
if not isinstance(result, coroutines.FromWrapper):
self._loop.call_soon(
self._step, None,
RuntimeError("yield used without From"))
return
result = result.obj
else:
# asyncio coroutine using "yield from ..."
if isinstance(result, coroutines.FromWrapper):
result = result.obj
elif isinstance(result, coroutines.FromWrapper):
result = result.obj
if coroutines.iscoroutine(result):
# "yield coroutine" creates a task, the current task
# will wait until the new task is done
result = self._loop.create_task(result)
# FIXME: faster check. common base class? hasattr?
elif isinstance(result, (Lock, Condition, Semaphore)):
coro = _lock_coroutine(result)
result = self._loop.create_task(coro)
if isinstance(result, futures._FUTURE_CLASSES):
# Yielded Future must come from Future.__iter__().
result.add_done_callback(self._wakeup)
self._fut_waiter = result
if self._must_cancel:
if self._fut_waiter.cancel():
self._must_cancel = False
elif result is None:
# Bare yield relinquishes control for one event loop iteration.
self._loop.call_soon(self._step)
elif inspect.isgenerator(result):
# Yielding a generator is just wrong.
self._loop.call_soon(
self._step, None,
RuntimeError(
'yield was used instead of yield from for '
'generator in task {!r} with {}'.format(
self, result)))
else:
# Yielding something else is an error.
self._loop.call_soon(
self._step, None,
RuntimeError(
'Task got bad yield: {!r}'.format(result)))
'Task got bad yield: {0!r}'.format(result)))
finally:
self.__class__._current_tasks.pop(self._loop)
self = None # Needed to break cycles when an exception occurs.
def _wakeup(self, future):
try:
value = future.result()
except Exception as exc:
# This may also be a cancellation.
self._step(None, exc)
if (future._state == futures._FINISHED
and future._exception is not None):
# Get the traceback before calling exception(), because calling
# the exception() method clears the traceback
exc_tb = future._get_exception_tb()
exc = future.exception()
self._step(None, exc, exc_tb)
exc_tb = None
else:
self._step(value, None)
try:
value = future.result()
except Exception as exc:
# This may also be a cancellation.
self._step(None, exc)
else:
self._step(value, None)
self = None # Needed to break cycles when an exception occurs.
# wait() and as_completed() similar to those in PEP 3148.
FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED
FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION
ALL_COMPLETED = concurrent.futures.ALL_COMPLETED
# Export symbols in trollius.tasks for compatibility with asyncio
FIRST_COMPLETED = executor.FIRST_COMPLETED
FIRST_EXCEPTION = executor.FIRST_EXCEPTION
ALL_COMPLETED = executor.ALL_COMPLETED
@coroutine
def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED):
def wait(fs, loop=None, timeout=None, return_when=ALL_COMPLETED):
"""Wait for the Futures and coroutines given by fs to complete.
The sequence futures must not be empty.
@@ -312,24 +363,25 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED):
Usage:
done, pending = yield from asyncio.wait(fs)
done, pending = yield From(asyncio.wait(fs))
Note: This does not raise TimeoutError! Futures that aren't done
when the timeout occurs are returned in the second set.
"""
if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs):
if isinstance(fs, futures._FUTURE_CLASSES) or coroutines.iscoroutine(fs):
raise TypeError("expect a list of futures, not %s" % type(fs).__name__)
if not fs:
raise ValueError('Set of coroutines/Futures is empty.')
if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED):
raise ValueError('Invalid return_when value: {}'.format(return_when))
raise ValueError('Invalid return_when value: {0}'.format(return_when))
if loop is None:
loop = events.get_event_loop()
fs = {ensure_future(f, loop=loop) for f in set(fs)}
fs = set(ensure_future(f, loop=loop) for f in set(fs))
return (yield from _wait(fs, timeout, return_when, loop))
result = yield From(_wait(fs, timeout, return_when, loop))
raise Return(result)
def _release_waiter(waiter, *args):
@@ -338,7 +390,7 @@ def _release_waiter(waiter, *args):
@coroutine
def wait_for(fut, timeout, *, loop=None):
def wait_for(fut, timeout, loop=None):
"""Wait for the single Future or coroutine to complete, with timeout.
Coroutine will be wrapped in Task.
@@ -355,7 +407,8 @@ def wait_for(fut, timeout, *, loop=None):
loop = events.get_event_loop()
if timeout is None:
return (yield from fut)
result = yield From(fut)
raise Return(result)
waiter = futures.Future(loop=loop)
timeout_handle = loop.call_later(timeout, _release_waiter, waiter)
@@ -367,14 +420,14 @@ def wait_for(fut, timeout, *, loop=None):
try:
# wait until the future completes or the timeout
try:
yield from waiter
yield From(waiter)
except futures.CancelledError:
fut.remove_done_callback(cb)
fut.cancel()
raise
if fut.done():
return fut.result()
raise Return(fut.result())
else:
fut.remove_done_callback(cb)
fut.cancel()
@@ -394,12 +447,11 @@ def _wait(fs, timeout, return_when, loop):
timeout_handle = None
if timeout is not None:
timeout_handle = loop.call_later(timeout, _release_waiter, waiter)
counter = len(fs)
non_local = {'counter': len(fs)}
def _on_completion(f):
nonlocal counter
counter -= 1
if (counter <= 0 or
non_local['counter'] -= 1
if (non_local['counter'] <= 0 or
return_when == FIRST_COMPLETED or
return_when == FIRST_EXCEPTION and (not f.cancelled() and
f.exception() is not None)):
@@ -412,7 +464,7 @@ def _wait(fs, timeout, return_when, loop):
f.add_done_callback(_on_completion)
try:
yield from waiter
yield From(waiter)
finally:
if timeout_handle is not None:
timeout_handle.cancel()
@@ -424,11 +476,11 @@ def _wait(fs, timeout, return_when, loop):
done.add(f)
else:
pending.add(f)
return done, pending
raise Return(done, pending)
# This is *not* a @coroutine! It is just an iterator (yielding Futures).
def as_completed(fs, *, loop=None, timeout=None):
def as_completed(fs, loop=None, timeout=None):
"""Return an iterator whose values are coroutines.
When waiting for the yielded coroutines you'll get the results (or
@@ -438,18 +490,18 @@ def as_completed(fs, *, loop=None, timeout=None):
This differs from PEP 3148; the proper way to use this is:
for f in as_completed(fs):
result = yield from f # The 'yield from' may raise.
result = yield From(f) # The 'yield' may raise.
# Use result.
If a timeout is specified, the 'yield from' will raise
If a timeout is specified, the 'yield' will raise
TimeoutError when the timeout occurs before all Futures are done.
Note: The futures 'f' are not necessarily members of fs.
"""
if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs):
if isinstance(fs, futures._FUTURE_CLASSES) or coroutines.iscoroutine(fs):
raise TypeError("expect a list of futures, not %s" % type(fs).__name__)
loop = loop if loop is not None else events.get_event_loop()
todo = {ensure_future(f, loop=loop) for f in set(fs)}
todo = set(ensure_future(f, loop=loop) for f in set(fs))
from .queues import Queue # Import here to avoid circular import problem.
done = Queue(loop=loop)
timeout_handle = None
@@ -470,11 +522,11 @@ def as_completed(fs, *, loop=None, timeout=None):
@coroutine
def _wait_for_one():
f = yield from done.get()
f = yield From(done.get())
if f is None:
# Dummy value from _on_timeout().
raise futures.TimeoutError
return f.result() # May raise f.exception().
raise Return(f.result()) # May raise f.exception().
for f in todo:
f.add_done_callback(_on_completion)
@@ -485,18 +537,19 @@ def as_completed(fs, *, loop=None, timeout=None):
@coroutine
def sleep(delay, result=None, *, loop=None):
def sleep(delay, result=None, loop=None):
"""Coroutine that completes after a given time (in seconds)."""
future = futures.Future(loop=loop)
h = future._loop.call_later(delay,
future._set_result_unless_cancelled, result)
try:
return (yield from future)
result = yield From(future)
raise Return(result)
finally:
h.cancel()
def async(coro_or_future, *, loop=None):
def async(coro_or_future, loop=None):
"""Wrap a coroutine in a future.
If the argument is a Future, it is returned directly.
@@ -515,7 +568,10 @@ def ensure_future(coro_or_future, *, loop=None):
If the argument is a Future, it is returned directly.
"""
if isinstance(coro_or_future, futures.Future):
# FIXME: only check if coroutines._DEBUG is True?
if isinstance(coro_or_future, coroutines.FromWrapper):
coro_or_future = coro_or_future.obj
if isinstance(coro_or_future, futures._FUTURE_CLASSES):
if loop is not None and loop is not coro_or_future._loop:
raise ValueError('loop argument must agree with Future')
return coro_or_future
@@ -538,8 +594,8 @@ class _GatheringFuture(futures.Future):
cancelled.
"""
def __init__(self, children, *, loop=None):
super().__init__(loop=loop)
def __init__(self, children, loop=None):
super(_GatheringFuture, self).__init__(loop=loop)
self._children = children
def cancel(self):
@@ -550,7 +606,7 @@ class _GatheringFuture(futures.Future):
return True
def gather(*coros_or_futures, loop=None, return_exceptions=False):
def gather(*coros_or_futures, **kw):
"""Return a future aggregating results from the given coroutines
or futures.
@@ -570,6 +626,11 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
prevent the cancellation of one child to cause other children to
be cancelled.)
"""
loop = kw.pop('loop', None)
return_exceptions = kw.pop('return_exceptions', False)
if kw:
raise TypeError("unexpected keyword")
if not coros_or_futures:
outer = futures.Future(loop=loop)
outer.set_result([])
@@ -577,7 +638,7 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
arg_to_fut = {}
for arg in set(coros_or_futures):
if not isinstance(arg, futures.Future):
if not isinstance(arg, futures._FUTURE_CLASSES):
fut = ensure_future(arg, loop=loop)
if loop is None:
loop = fut._loop
@@ -595,11 +656,10 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
children = [arg_to_fut[arg] for arg in coros_or_futures]
nchildren = len(children)
outer = _GatheringFuture(children, loop=loop)
nfinished = 0
non_local = {'nfinished': 0}
results = [None] * nchildren
def _done_callback(i, fut):
nonlocal nfinished
if outer.done():
if not fut.cancelled():
# Mark exception retrieved.
@@ -619,8 +679,8 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
else:
res = fut._result
results[i] = res
nfinished += 1
if nfinished == nchildren:
non_local['nfinished'] += 1
if non_local['nfinished'] == nchildren:
outer.set_result(results)
for i, fut in enumerate(children):
@@ -628,16 +688,16 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
return outer
def shield(arg, *, loop=None):
def shield(arg, loop=None):
"""Wait for a future, shielding it from cancellation.
The statement
res = yield from shield(something())
res = yield From(shield(something()))
is exactly equivalent to the statement
res = yield from something()
res = yield From(something())
*except* that if the coroutine containing it is cancelled, the
task running in something() is not cancelled. From the POV of
@@ -650,7 +710,7 @@ def shield(arg, *, loop=None):
you can combine shield() with a try/except clause, as follows:
try:
res = yield from shield(something())
res = yield From(shield(something()))
except CancelledError:
res = None
"""

View File

@@ -4,6 +4,7 @@
# Ignore symbol TEST_HOME_DIR: test_events works without it
from __future__ import absolute_import
import functools
import gc
import os
@@ -14,6 +15,7 @@ import subprocess
import sys
import time
from trollius import test_utils
# A constant likely larger than the underlying OS pipe buffer size, to
# make writes blocking.
@@ -39,7 +41,9 @@ def _assert_python(expected_success, *args, **env_vars):
isolated = env_vars.pop('__isolated')
else:
isolated = not env_vars
cmd_line = [sys.executable, '-X', 'faulthandler']
cmd_line = [sys.executable]
if sys.version_info >= (3, 3):
cmd_line.extend(('-X', 'faulthandler'))
if isolated and sys.version_info >= (3, 4):
# isolated mode: ignore Python environment variables, ignore user
# site-packages, and don't add the current directory to sys.path
@@ -248,7 +252,7 @@ def requires_mac_ver(*min_version):
else:
if version < min_version:
min_version_txt = '.'.join(map(str, min_version))
raise unittest.SkipTest(
raise test_utils.SkipTest(
"Mac OS X %s or higher required, not %s"
% (min_version_txt, version_txt))
return func(*args, **kw)
@@ -275,7 +279,7 @@ def _requires_unix_version(sysname, min_version):
else:
if version < min_version:
min_version_txt = '.'.join(map(str, min_version))
raise unittest.SkipTest(
raise test_utils.SkipTest(
"%s version %s or higher required, not %s"
% (sysname, min_version_txt, version_txt))
return func(*args, **kw)
@@ -300,9 +304,6 @@ except ImportError:
# Use test.script_helper if available
try:
from test.support.script_helper import assert_python_ok
from test.script_helper import assert_python_ok
except ImportError:
try:
from test.script_helper import assert_python_ok
except ImportError:
pass
pass

View File

@@ -7,20 +7,32 @@ import logging
import os
import re
import socket
import socketserver
import sys
import tempfile
import threading
import time
import unittest
from unittest import mock
from http.server import HTTPServer
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
try:
import socketserver
from http.server import HTTPServer
except ImportError:
# Python 2
import SocketServer as socketserver
from BaseHTTPServer import HTTPServer
try:
from unittest import mock
except ImportError:
# Python < 3.3
import mock
try:
import ssl
from .py3_ssl import SSLContext, wrap_socket
except ImportError: # pragma: no cover
# SSL support disabled in Python
ssl = None
from . import base_events
@@ -37,27 +49,116 @@ if sys.platform == 'win32': # pragma: no cover
else:
from socket import socketpair # pragma: no cover
try:
import unittest
skipIf = unittest.skipIf
skipUnless = unittest.skipUnless
SkipTest = unittest.SkipTest
_TestCase = unittest.TestCase
except AttributeError:
# Python 2.6: use the backported unittest module called "unittest2"
import unittest2
skipIf = unittest2.skipIf
skipUnless = unittest2.skipUnless
SkipTest = unittest2.SkipTest
_TestCase = unittest2.TestCase
if not hasattr(_TestCase, 'assertRaisesRegex'):
class _BaseTestCaseContext:
def __init__(self, test_case):
self.test_case = test_case
def _raiseFailure(self, standardMsg):
msg = self.test_case._formatMessage(self.msg, standardMsg)
raise self.test_case.failureException(msg)
class _AssertRaisesBaseContext(_BaseTestCaseContext):
def __init__(self, expected, test_case, callable_obj=None,
expected_regex=None):
_BaseTestCaseContext.__init__(self, test_case)
self.expected = expected
self.test_case = test_case
if callable_obj is not None:
try:
self.obj_name = callable_obj.__name__
except AttributeError:
self.obj_name = str(callable_obj)
else:
self.obj_name = None
if isinstance(expected_regex, (bytes, str)):
expected_regex = re.compile(expected_regex)
self.expected_regex = expected_regex
self.msg = None
def handle(self, name, callable_obj, args, kwargs):
"""
If callable_obj is None, assertRaises/Warns is being used as a
context manager, so check for a 'msg' kwarg and return self.
If callable_obj is not None, call it passing args and kwargs.
"""
if callable_obj is None:
self.msg = kwargs.pop('msg', None)
return self
with self:
callable_obj(*args, **kwargs)
class _AssertRaisesContext(_AssertRaisesBaseContext):
"""A context manager used to implement TestCase.assertRaises* methods."""
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, tb):
if exc_type is None:
try:
exc_name = self.expected.__name__
except AttributeError:
exc_name = str(self.expected)
if self.obj_name:
self._raiseFailure("{0} not raised by {1}".format(exc_name,
self.obj_name))
else:
self._raiseFailure("{0} not raised".format(exc_name))
if not issubclass(exc_type, self.expected):
# let unexpected exceptions pass through
return False
self.exception = exc_value
if self.expected_regex is None:
return True
expected_regex = self.expected_regex
if not expected_regex.search(str(exc_value)):
self._raiseFailure('"{0}" does not match "{1}"'.format(
expected_regex.pattern, str(exc_value)))
return True
def dummy_ssl_context():
if ssl is None:
return None
else:
return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
return SSLContext(ssl.PROTOCOL_SSLv23)
def run_briefly(loop):
def run_briefly(loop, steps=1):
@coroutine
def once():
pass
gen = once()
t = loop.create_task(gen)
# Don't log a warning if the task is not done after run_until_complete().
# It occurs if the loop is stopped or if a task raises a BaseException.
t._log_destroy_pending = False
try:
loop.run_until_complete(t)
finally:
gen.close()
for step in range(steps):
gen = once()
t = loop.create_task(gen)
# Don't log a warning if the task is not done after run_until_complete().
# It occurs if the loop is stopped or if a task raises a BaseException.
t._log_destroy_pending = False
try:
loop.run_until_complete(t)
finally:
gen.close()
def run_until(loop, pred, timeout=30):
@@ -89,12 +190,12 @@ class SilentWSGIRequestHandler(WSGIRequestHandler):
pass
class SilentWSGIServer(WSGIServer):
class SilentWSGIServer(WSGIServer, object):
request_timeout = 2
def get_request(self):
request, client_addr = super().get_request()
request, client_addr = super(SilentWSGIServer, self).get_request()
request.settimeout(self.request_timeout)
return request, client_addr
@@ -115,10 +216,10 @@ class SSLWSGIServerMixin:
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
ssock = wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
@@ -131,7 +232,7 @@ class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
pass
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
def _run_test_server(address, use_ssl, server_cls, server_ssl_cls):
def app(environ, start_response):
status = '200 OK'
@@ -158,7 +259,7 @@ def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
if hasattr(socket, 'AF_UNIX'):
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer, object):
def server_bind(self):
socketserver.UnixStreamServer.server_bind(self)
@@ -166,7 +267,7 @@ if hasattr(socket, 'AF_UNIX'):
self.server_port = 80
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
class UnixWSGIServer(UnixHTTPServer, WSGIServer, object):
request_timeout = 2
@@ -175,7 +276,7 @@ if hasattr(socket, 'AF_UNIX'):
self.setup_environ()
def get_request(self):
request, client_addr = super().get_request()
request, client_addr = super(UnixWSGIServer, self).get_request()
request.settimeout(self.request_timeout)
# Code in the stdlib expects that get_request
# will return a socket and a tuple (host, port).
@@ -214,18 +315,20 @@ if hasattr(socket, 'AF_UNIX'):
@contextlib.contextmanager
def run_test_unix_server(*, use_ssl=False):
def run_test_unix_server(use_ssl=False):
with unix_socket_path() as path:
yield from _run_test_server(address=path, use_ssl=use_ssl,
server_cls=SilentUnixWSGIServer,
server_ssl_cls=UnixSSLWSGIServer)
for item in _run_test_server(address=path, use_ssl=use_ssl,
server_cls=SilentUnixWSGIServer,
server_ssl_cls=UnixSSLWSGIServer):
yield item
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
server_cls=SilentWSGIServer,
server_ssl_cls=SSLWSGIServer)
def run_test_server(host='127.0.0.1', port=0, use_ssl=False):
for item in _run_test_server(address=(host, port), use_ssl=use_ssl,
server_cls=SilentWSGIServer,
server_ssl_cls=SSLWSGIServer):
yield item
def make_test_protocol(base):
@@ -278,7 +381,7 @@ class TestLoop(base_events.BaseEventLoop):
"""
def __init__(self, gen=None):
super().__init__()
super(TestLoop, self).__init__()
if gen is None:
def gen():
@@ -307,7 +410,7 @@ class TestLoop(base_events.BaseEventLoop):
self._time += advance
def close(self):
super().close()
super(TestLoop, self).close()
if self._check_on_close:
try:
self._gen.send(0)
@@ -328,11 +431,11 @@ class TestLoop(base_events.BaseEventLoop):
return False
def assert_reader(self, fd, callback, *args):
assert fd in self.readers, 'fd {} is not registered'.format(fd)
assert fd in self.readers, 'fd {0} is not registered'.format(fd)
handle = self.readers[fd]
assert handle._callback == callback, '{!r} != {!r}'.format(
assert handle._callback == callback, '{0!r} != {1!r}'.format(
handle._callback, callback)
assert handle._args == args, '{!r} != {!r}'.format(
assert handle._args == args, '{0!r} != {1!r}'.format(
handle._args, args)
def add_writer(self, fd, callback, *args):
@@ -347,11 +450,11 @@ class TestLoop(base_events.BaseEventLoop):
return False
def assert_writer(self, fd, callback, *args):
assert fd in self.writers, 'fd {} is not registered'.format(fd)
assert fd in self.writers, 'fd {0} is not registered'.format(fd)
handle = self.writers[fd]
assert handle._callback == callback, '{!r} != {!r}'.format(
assert handle._callback == callback, '{0!r} != {1!r}'.format(
handle._callback, callback)
assert handle._args == args, '{!r} != {!r}'.format(
assert handle._args == args, '{0!r} != {1!r}'.format(
handle._args, args)
def reset_counters(self):
@@ -359,7 +462,7 @@ class TestLoop(base_events.BaseEventLoop):
self.remove_writer_count = collections.defaultdict(int)
def _run_once(self):
super()._run_once()
super(TestLoop, self)._run_once()
for when in self._timers:
advance = self._gen.send(when)
self.advance_time(advance)
@@ -367,7 +470,7 @@ class TestLoop(base_events.BaseEventLoop):
def call_at(self, when, callback, *args):
self._timers.append(when)
return super().call_at(when, callback, *args)
return super(TestLoop, self).call_at(when, callback, *args)
def _process_events(self, event_list):
return
@@ -400,8 +503,8 @@ def get_function_source(func):
return source
class TestCase(unittest.TestCase):
def set_event_loop(self, loop, *, cleanup=True):
class TestCase(_TestCase):
def set_event_loop(self, loop, cleanup=True):
assert loop is not None
# ensure that the event loop is passed explicitly in asyncio
events.set_event_loop(None)
@@ -420,6 +523,48 @@ class TestCase(unittest.TestCase):
# in an except block of a generator
self.assertEqual(sys.exc_info(), (None, None, None))
if not hasattr(_TestCase, 'assertRaisesRegex'):
def assertRaisesRegex(self, expected_exception, expected_regex,
callable_obj=None, *args, **kwargs):
"""Asserts that the message in a raised exception matches a regex.
Args:
expected_exception: Exception class expected to be raised.
expected_regex: Regex (re pattern object or string) expected
to be found in error message.
callable_obj: Function to be called.
msg: Optional message used in case of failure. Can only be used
when assertRaisesRegex is used as a context manager.
args: Extra args.
kwargs: Extra kwargs.
"""
context = _AssertRaisesContext(expected_exception, self, callable_obj,
expected_regex)
return context.handle('assertRaisesRegex', callable_obj, args, kwargs)
if not hasattr(_TestCase, 'assertRegex'):
def assertRegex(self, text, expected_regex, msg=None):
"""Fail the test unless the text matches the regular expression."""
if isinstance(expected_regex, (str, bytes)):
assert expected_regex, "expected_regex must not be empty."
expected_regex = re.compile(expected_regex)
if not expected_regex.search(text):
msg = msg or "Regex didn't match"
msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text)
raise self.failureException(msg)
def check_soure_traceback(self, source_traceback, lineno_delta):
frame = sys._getframe(1)
filename = frame.f_code.co_filename
lineno = frame.f_lineno + lineno_delta
name = frame.f_code.co_name
self.assertIsInstance(source_traceback, list)
self.assertEqual(source_traceback[-1][:3],
(filename,
lineno,
name))
@contextlib.contextmanager
def disable_logger():

View File

@@ -1,6 +1,7 @@
"""Abstract Transport class."""
import sys
from .compat import flatten_bytes
_PY34 = sys.version_info >= (3, 4)
@@ -9,7 +10,7 @@ __all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport',
]
class BaseTransport:
class BaseTransport(object):
"""Base class for transports."""
def __init__(self, extra=None):
@@ -94,12 +95,8 @@ class WriteTransport(BaseTransport):
The default implementation concatenates the arguments and
calls write() on the result.
"""
if not _PY34:
# In Python 3.3, bytes.join() doesn't handle memoryview.
list_of_data = (
bytes(data) if isinstance(data, memoryview) else data
for data in list_of_data)
self.write(b''.join(list_of_data))
data = map(flatten_bytes, list_of_data)
self.write(b''.join(data))
def write_eof(self):
"""Close the write end after flushing buffered data.
@@ -230,7 +227,7 @@ class _FlowControlMixin(Transport):
override set_write_buffer_limits() (e.g. to specify different
defaults).
The subclass constructor must call super().__init__(extra). This
The subclass constructor must call super(Class, self).__init__(extra). This
will call set_write_buffer_limits().
The user may call set_write_buffer_limits() and
@@ -239,7 +236,7 @@ class _FlowControlMixin(Transport):
"""
def __init__(self, extra=None, loop=None):
super().__init__(extra)
super(_FlowControlMixin, self).__init__(extra)
assert loop is not None
self._loop = loop
self._protocol_paused = False

View File

@@ -1,4 +1,5 @@
"""Selector event loop for Unix with signal handling."""
from __future__ import absolute_import
import errno
import os
@@ -13,6 +14,7 @@ import warnings
from . import base_events
from . import base_subprocess
from . import compat
from . import constants
from . import coroutines
from . import events
@@ -20,8 +22,13 @@ from . import futures
from . import selector_events
from . import selectors
from . import transports
from .coroutines import coroutine
from .compat import flatten_bytes
from .coroutines import coroutine, From, Return
from .log import logger
from .py33_exceptions import (
reraise, wrap_error,
BlockingIOError, BrokenPipeError, ConnectionResetError,
InterruptedError, ChildProcessError)
__all__ = ['SelectorEventLoop',
@@ -33,9 +40,10 @@ if sys.platform == 'win32': # pragma: no cover
raise ImportError('Signals are not really supported on Windows')
def _sighandler_noop(signum, frame):
"""Dummy signal handler."""
pass
if compat.PY33:
def _sighandler_noop(signum, frame):
"""Dummy signal handler."""
pass
class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
@@ -45,23 +53,27 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
"""
def __init__(self, selector=None):
super().__init__(selector)
super(_UnixSelectorEventLoop, self).__init__(selector)
self._signal_handlers = {}
def _socketpair(self):
return socket.socketpair()
def close(self):
super().close()
super(_UnixSelectorEventLoop, self).close()
for sig in list(self._signal_handlers):
self.remove_signal_handler(sig)
def _process_self_data(self, data):
for signum in data:
if not signum:
# ignore null bytes written by _write_to_self()
continue
self._handle_signal(signum)
# On Python <= 3.2, the C signal handler of Python writes a null byte into
# the wakeup file descriptor. We cannot retrieve the signal numbers from
# the file descriptor.
if compat.PY33:
def _process_self_data(self, data):
for signum in data:
if not signum:
# ignore null bytes written by _write_to_self()
continue
self._handle_signal(signum)
def add_signal_handler(self, sig, callback, *args):
"""Add a handler for a signal. UNIX only.
@@ -88,14 +100,30 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
self._signal_handlers[sig] = handle
try:
# Register a dummy signal handler to ask Python to write the signal
# number in the wakup file descriptor. _process_self_data() will
# read signal numbers from this file descriptor to handle signals.
signal.signal(sig, _sighandler_noop)
if compat.PY33:
# On Python 3.3 and newer, the C signal handler writes the
# signal number into the wakeup file descriptor and then calls
# Py_AddPendingCall() to schedule the Python signal handler.
#
# Register a dummy signal handler to ask Python to write the
# signal number into the wakup file descriptor.
# _process_self_data() will read signal numbers from this file
# descriptor to handle signals.
signal.signal(sig, _sighandler_noop)
else:
# On Python 3.2 and older, the C signal handler first calls
# Py_AddPendingCall() to schedule the Python signal handler,
# and then write a null byte into the wakeup file descriptor.
signal.signal(sig, self._handle_signal)
# Set SA_RESTART to limit EINTR occurrences.
signal.siginterrupt(sig, False)
except OSError as exc:
except (RuntimeError, OSError) as exc:
# On Python 2, signal.signal(signal.SIGKILL, signal.SIG_IGN) raises
# RuntimeError(22, 'Invalid argument'). On Python 3,
# OSError(22, 'Invalid argument') is raised instead.
exc_type, exc_value, tb = sys.exc_info()
del self._signal_handlers[sig]
if not self._signal_handlers:
try:
@@ -103,12 +131,12 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
except (ValueError, OSError) as nexc:
logger.info('set_wakeup_fd(-1) failed: %s', nexc)
if exc.errno == errno.EINVAL:
raise RuntimeError('sig {} cannot be caught'.format(sig))
if isinstance(exc, RuntimeError) or exc.errno == errno.EINVAL:
raise RuntimeError('sig {0} cannot be caught'.format(sig))
else:
raise
reraise(exc_type, exc_value, tb)
def _handle_signal(self, sig):
def _handle_signal(self, sig, frame=None):
"""Internal helper that is the actual signal handler."""
handle = self._signal_handlers.get(sig)
if handle is None:
@@ -138,7 +166,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
signal.signal(sig, handler)
except OSError as exc:
if exc.errno == errno.EINVAL:
raise RuntimeError('sig {} cannot be caught'.format(sig))
raise RuntimeError('sig {0} cannot be caught'.format(sig))
else:
raise
@@ -157,11 +185,11 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
Raise RuntimeError if there is a problem setting up the handler.
"""
if not isinstance(sig, int):
raise TypeError('sig must be an int, not {!r}'.format(sig))
raise TypeError('sig must be an int, not {0!r}'.format(sig))
if not (1 <= sig < signal.NSIG):
raise ValueError(
'sig {} out of range(1, {})'.format(sig, signal.NSIG))
'sig {0} out of range(1, {1})'.format(sig, signal.NSIG))
def _make_read_pipe_transport(self, pipe, protocol, waiter=None,
extra=None):
@@ -185,7 +213,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
watcher.add_child_handler(transp.get_pid(),
self._child_watcher_callback, transp)
try:
yield from waiter
yield From(waiter)
except Exception as exc:
# Workaround CPython bug #23353: using yield/yield-from in an
# except block of a generator doesn't clear properly
@@ -196,16 +224,16 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
if err is not None:
transp.close()
yield from transp._wait()
yield From(transp._wait())
raise err
return transp
raise Return(transp)
def _child_watcher_callback(self, pid, returncode, transp):
self.call_soon_threadsafe(transp._process_exited, returncode)
@coroutine
def create_unix_connection(self, protocol_factory, path, *,
def create_unix_connection(self, protocol_factory, path,
ssl=None, sock=None,
server_hostname=None):
assert server_hostname is None or isinstance(server_hostname, str)
@@ -225,7 +253,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
try:
sock.setblocking(False)
yield from self.sock_connect(sock, path)
yield From(self.sock_connect(sock, path))
except:
sock.close()
raise
@@ -235,12 +263,12 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
raise ValueError('no path and sock were specified')
sock.setblocking(False)
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
return transport, protocol
transport, protocol = yield From(self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname))
raise Return(transport, protocol)
@coroutine
def create_unix_server(self, protocol_factory, path=None, *,
def create_unix_server(self, protocol_factory, path=None,
sock=None, backlog=100, ssl=None):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
@@ -254,13 +282,13 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
try:
sock.bind(path)
except OSError as exc:
except socket.error as exc:
sock.close()
if exc.errno == errno.EADDRINUSE:
# Let's improve the error message by adding
# with what exact address it occurs.
msg = 'Address {!r} is already in use'.format(path)
raise OSError(errno.EADDRINUSE, msg) from None
msg = 'Address {0!r} is already in use'.format(path)
raise OSError(errno.EADDRINUSE, msg)
else:
raise
except:
@@ -273,7 +301,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
if sock.family != socket.AF_UNIX:
raise ValueError(
'A UNIX Domain Socket was expected, got {!r}'.format(sock))
'A UNIX Domain Socket was expected, got {0!r}'.format(sock))
server = base_events.Server(self, [sock])
sock.listen(backlog)
@@ -283,6 +311,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
if hasattr(os, 'set_blocking'):
# Python 3.5 and newer
def _set_nonblocking(fd):
os.set_blocking(fd, False)
else:
@@ -299,7 +328,7 @@ class _UnixReadPipeTransport(transports.ReadTransport):
max_size = 256 * 1024 # max bytes we read in one event loop iteration
def __init__(self, loop, pipe, protocol, waiter=None, extra=None):
super().__init__(extra)
super(_UnixReadPipeTransport, self).__init__(extra)
self._extra['pipe'] = pipe
self._loop = loop
self._pipe = pipe
@@ -341,7 +370,7 @@ class _UnixReadPipeTransport(transports.ReadTransport):
def _read_ready(self):
try:
data = os.read(self._fileno, self.max_size)
data = wrap_error(os.read, self._fileno, self.max_size)
except (BlockingIOError, InterruptedError):
pass
except OSError as exc:
@@ -409,7 +438,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
transports.WriteTransport):
def __init__(self, loop, pipe, protocol, waiter=None, extra=None):
super().__init__(extra, loop)
super(_UnixWritePipeTransport, self).__init__(extra, loop)
self._extra['pipe'] = pipe
self._pipe = pipe
self._fileno = pipe.fileno()
@@ -475,9 +504,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
self._close()
def write(self, data):
assert isinstance(data, (bytes, bytearray, memoryview)), repr(data)
if isinstance(data, bytearray):
data = memoryview(data)
data = flatten_bytes(data)
if not data:
return
@@ -491,7 +518,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
if not self._buffer:
# Attempt to send it right away first.
try:
n = os.write(self._fileno, data)
n = wrap_error(os.write, self._fileno, data)
except (BlockingIOError, InterruptedError):
n = 0
except Exception as exc:
@@ -511,9 +538,9 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
data = b''.join(self._buffer)
assert data, 'Data should not be empty'
self._buffer.clear()
del self._buffer[:]
try:
n = os.write(self._fileno, data)
n = wrap_error(os.write, self._fileno, data)
except (BlockingIOError, InterruptedError):
self._buffer.append(data)
except Exception as exc:
@@ -582,7 +609,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
self._closing = True
if self._buffer:
self._loop.remove_writer(self._fileno)
self._buffer.clear()
del self._buffer[:]
self._loop.remove_reader(self._fileno)
self._loop.call_soon(self._call_connection_lost, exc)
@@ -633,11 +660,20 @@ class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport):
args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr,
universal_newlines=False, bufsize=bufsize, **kwargs)
if stdin_w is not None:
# Retrieve the file descriptor from stdin_w, stdin_w should not
# "own" the file descriptor anymore: closing stdin_fd file
# descriptor must close immediatly the file
stdin.close()
self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize)
if hasattr(stdin_w, 'detach'):
stdin_fd = stdin_w.detach()
self._proc.stdin = os.fdopen(stdin_fd, 'wb', bufsize)
else:
stdin_dup = os.dup(stdin_w.fileno())
stdin_w.close()
self._proc.stdin = os.fdopen(stdin_dup, 'wb', bufsize)
class AbstractChildWatcher:
class AbstractChildWatcher(object):
"""Abstract base class for monitoring child processes.
Objects derived from this class monitor a collection of subprocesses and
@@ -773,12 +809,12 @@ class SafeChildWatcher(BaseChildWatcher):
"""
def __init__(self):
super().__init__()
super(SafeChildWatcher, self).__init__()
self._callbacks = {}
def close(self):
self._callbacks.clear()
super().close()
super(SafeChildWatcher, self).close()
def __enter__(self):
return self
@@ -850,7 +886,7 @@ class FastChildWatcher(BaseChildWatcher):
(O(1) each time a child terminates).
"""
def __init__(self):
super().__init__()
super(FastChildWatcher, self).__init__()
self._callbacks = {}
self._lock = threading.Lock()
self._zombies = {}
@@ -859,7 +895,7 @@ class FastChildWatcher(BaseChildWatcher):
def close(self):
self._callbacks.clear()
self._zombies.clear()
super().close()
super(FastChildWatcher, self).close()
def __enter__(self):
with self._lock:
@@ -906,7 +942,7 @@ class FastChildWatcher(BaseChildWatcher):
# long as we're able to reap a child.
while True:
try:
pid, status = os.waitpid(-1, os.WNOHANG)
pid, status = wrap_error(os.waitpid, -1, os.WNOHANG)
except ChildProcessError:
# No more child processes exist.
return
@@ -949,7 +985,7 @@ class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy):
_loop_factory = _UnixSelectorEventLoop
def __init__(self):
super().__init__()
super(_UnixDefaultEventLoopPolicy, self).__init__()
self._watcher = None
def _init_watcher(self):
@@ -968,7 +1004,7 @@ class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy):
the child watcher.
"""
super().set_event_loop(loop)
super(_UnixDefaultEventLoopPolicy, self).set_event_loop(loop)
if self._watcher is not None and \
isinstance(threading.current_thread(), threading._MainThread):

View File

@@ -11,12 +11,15 @@ from . import events
from . import base_subprocess
from . import futures
from . import proactor_events
from . import py33_winapi as _winapi
from . import selector_events
from . import tasks
from . import windows_utils
from . import _overlapped
from .coroutines import coroutine
from .coroutines import coroutine, From, Return
from .log import logger
from .py33_exceptions import (wrap_error, get_error_class,
ConnectionRefusedError, BrokenPipeError)
__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor',
@@ -42,14 +45,14 @@ class _OverlappedFuture(futures.Future):
Cancelling it will immediately cancel the overlapped operation.
"""
def __init__(self, ov, *, loop=None):
super().__init__(loop=loop)
def __init__(self, ov, loop=None):
super(_OverlappedFuture, self).__init__(loop=loop)
if self._source_traceback:
del self._source_traceback[-1]
self._ov = ov
def _repr_info(self):
info = super()._repr_info()
info = super(_OverlappedFuture, self)._repr_info()
if self._ov is not None:
state = 'pending' if self._ov.pending else 'completed'
info.insert(1, 'overlapped=<%s, %#x>' % (state, self._ov.address))
@@ -73,22 +76,22 @@ class _OverlappedFuture(futures.Future):
def cancel(self):
self._cancel_overlapped()
return super().cancel()
return super(_OverlappedFuture, self).cancel()
def set_exception(self, exception):
super().set_exception(exception)
super(_OverlappedFuture, self).set_exception(exception)
self._cancel_overlapped()
def set_result(self, result):
super().set_result(result)
super(_OverlappedFuture, self).set_result(result)
self._ov = None
class _BaseWaitHandleFuture(futures.Future):
"""Subclass of Future which represents a wait handle."""
def __init__(self, ov, handle, wait_handle, *, loop=None):
super().__init__(loop=loop)
def __init__(self, ov, handle, wait_handle, loop=None):
super(_BaseWaitHandleFuture, self).__init__(loop=loop)
if self._source_traceback:
del self._source_traceback[-1]
# Keep a reference to the Overlapped object to keep it alive until the
@@ -107,7 +110,7 @@ class _BaseWaitHandleFuture(futures.Future):
_winapi.WAIT_OBJECT_0)
def _repr_info(self):
info = super()._repr_info()
info = super(_BaseWaitHandleFuture, self)._repr_info()
info.append('handle=%#x' % self._handle)
if self._handle is not None:
state = 'signaled' if self._poll() else 'waiting'
@@ -147,15 +150,15 @@ class _BaseWaitHandleFuture(futures.Future):
def cancel(self):
self._unregister_wait()
return super().cancel()
return super(_BaseWaitHandleFuture, self).cancel()
def set_exception(self, exception):
self._unregister_wait()
super().set_exception(exception)
super(_BaseWaitHandleFuture, self).set_exception(exception)
def set_result(self, result):
self._unregister_wait()
super().set_result(result)
super(_BaseWaitHandleFuture, self).set_result(result)
class _WaitCancelFuture(_BaseWaitHandleFuture):
@@ -163,8 +166,9 @@ class _WaitCancelFuture(_BaseWaitHandleFuture):
_WaitHandleFuture using an event.
"""
def __init__(self, ov, event, wait_handle, *, loop=None):
super().__init__(ov, event, wait_handle, loop=loop)
def __init__(self, ov, event, wait_handle, loop=None):
super(_WaitCancelFuture, self).__init__(ov, event, wait_handle,
loop=loop)
self._done_callback = None
@@ -178,8 +182,9 @@ class _WaitCancelFuture(_BaseWaitHandleFuture):
class _WaitHandleFuture(_BaseWaitHandleFuture):
def __init__(self, ov, handle, wait_handle, proactor, *, loop=None):
super().__init__(ov, handle, wait_handle, loop=loop)
def __init__(self, ov, handle, wait_handle, proactor, loop=None):
super(_WaitHandleFuture, self).__init__(ov, handle, wait_handle,
loop=loop)
self._proactor = proactor
self._unregister_proactor = True
self._event = _overlapped.CreateEvent(None, True, False, None)
@@ -201,7 +206,7 @@ class _WaitHandleFuture(_BaseWaitHandleFuture):
self._proactor._unregister(self._ov)
self._proactor = None
super()._unregister_wait_cb(fut)
super(_WaitHandleFuture, self)._unregister_wait_cb(fut)
def _unregister_wait(self):
if not self._registered:
@@ -259,7 +264,7 @@ class PipeServer(object):
flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED
if first:
flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
h = _winapi.CreateNamedPipe(
h = wrap_error(_winapi.CreateNamedPipe,
self._address, flags,
_winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE |
_winapi.PIPE_WAIT,
@@ -301,7 +306,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
def __init__(self, proactor=None):
if proactor is None:
proactor = IocpProactor()
super().__init__(proactor)
super(ProactorEventLoop, self).__init__(proactor)
def _socketpair(self):
return windows_utils.socketpair()
@@ -309,11 +314,11 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
@coroutine
def create_pipe_connection(self, protocol_factory, address):
f = self._proactor.connect_pipe(address)
pipe = yield from f
pipe = yield From(f)
protocol = protocol_factory()
trans = self._make_duplex_pipe_transport(pipe, protocol,
extra={'addr': address})
return trans, protocol
raise Return(trans, protocol)
@coroutine
def start_serving_pipe(self, protocol_factory, address):
@@ -372,7 +377,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
waiter=waiter, extra=extra,
**kwargs)
try:
yield from waiter
yield From(waiter)
except Exception as exc:
# Workaround CPython bug #23353: using yield/yield-from in an
# except block of a generator doesn't clear properly sys.exc_info()
@@ -382,13 +387,13 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
if err is not None:
transp.close()
yield from transp._wait()
yield From(transp._wait())
raise err
return transp
raise Return(transp)
class IocpProactor:
class IocpProactor(object):
"""Proactor implementation using IOCP."""
def __init__(self, concurrency=0xffffffff):
@@ -426,16 +431,16 @@ class IocpProactor:
ov = _overlapped.Overlapped(NULL)
try:
if isinstance(conn, socket.socket):
ov.WSARecv(conn.fileno(), nbytes, flags)
wrap_error(ov.WSARecv, conn.fileno(), nbytes, flags)
else:
ov.ReadFile(conn.fileno(), nbytes)
wrap_error(ov.ReadFile, conn.fileno(), nbytes)
except BrokenPipeError:
return self._result(b'')
def finish_recv(trans, key, ov):
try:
return ov.getresult()
except OSError as exc:
return wrap_error(ov.getresult)
except WindowsError as exc:
if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
raise ConnectionResetError(*exc.args)
else:
@@ -453,8 +458,8 @@ class IocpProactor:
def finish_send(trans, key, ov):
try:
return ov.getresult()
except OSError as exc:
return wrap_error(ov.getresult)
except WindowsError as exc:
if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
raise ConnectionResetError(*exc.args)
else:
@@ -469,7 +474,7 @@ class IocpProactor:
ov.AcceptEx(listener.fileno(), conn.fileno())
def finish_accept(trans, key, ov):
ov.getresult()
wrap_error(ov.getresult)
# Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work.
buf = struct.pack('@P', listener.fileno())
conn.setsockopt(socket.SOL_SOCKET,
@@ -481,7 +486,7 @@ class IocpProactor:
def accept_coro(future, conn):
# Coroutine closing the accept socket if the future is cancelled
try:
yield from future
yield From(future)
except futures.CancelledError:
conn.close()
raise
@@ -496,7 +501,7 @@ class IocpProactor:
# The socket needs to be locally bound before we call ConnectEx().
try:
_overlapped.BindLocal(conn.fileno(), conn.family)
except OSError as e:
except WindowsError as e:
if e.winerror != errno.WSAEINVAL:
raise
# Probably already locally bound; check using getsockname().
@@ -506,7 +511,7 @@ class IocpProactor:
ov.ConnectEx(conn.fileno(), address)
def finish_connect(trans, key, ov):
ov.getresult()
wrap_error(ov.getresult)
# Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work.
conn.setsockopt(socket.SOL_SOCKET,
_overlapped.SO_UPDATE_CONNECT_CONTEXT, 0)
@@ -526,7 +531,7 @@ class IocpProactor:
return self._result(pipe)
def finish_accept_pipe(trans, key, ov):
ov.getresult()
wrap_error(ov.getresult)
return pipe
return self._register(ov, pipe, finish_accept_pipe)
@@ -539,17 +544,17 @@ class IocpProactor:
# Call CreateFile() in a loop until it doesn't fail with
# ERROR_PIPE_BUSY
try:
handle = _overlapped.ConnectPipe(address)
handle = wrap_error(_overlapped.ConnectPipe, address)
break
except OSError as exc:
except WindowsError as exc:
if exc.winerror != _overlapped.ERROR_PIPE_BUSY:
raise
# ConnectPipe() failed with ERROR_PIPE_BUSY: retry later
delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY)
yield from tasks.sleep(delay, loop=self._loop)
yield From(tasks.sleep(delay, loop=self._loop))
return windows_utils.PipeHandle(handle)
raise Return(windows_utils.PipeHandle(handle))
def wait_for_handle(self, handle, timeout=None):
"""Wait for a handle.
@@ -572,7 +577,7 @@ class IocpProactor:
else:
# RegisterWaitForSingleObject() has a resolution of 1 millisecond,
# round away from zero to wait *at least* timeout seconds.
ms = math.ceil(timeout * 1e3)
ms = int(math.ceil(timeout * 1e3))
# We only create ov so we can use ov.address as a key for the cache.
ov = _overlapped.Overlapped(NULL)
@@ -660,7 +665,7 @@ class IocpProactor:
else:
# GetQueuedCompletionStatus() has a resolution of 1 millisecond,
# round away from zero to wait *at least* timeout seconds.
ms = math.ceil(timeout * 1e3)
ms = int(math.ceil(timeout * 1e3))
if ms >= INFINITE:
raise ValueError("timeout too big")
@@ -705,7 +710,7 @@ class IocpProactor:
# Remove unregisted futures
for ov in self._unregistered:
self._cache.pop(ov.address, None)
self._unregistered.clear()
del self._unregistered[:]
def _stop_serving(self, obj):
# obj is a socket or pipe handle. It will be closed in

View File

@@ -1,6 +1,7 @@
"""
Various Windows specific bits and pieces
"""
from __future__ import absolute_import
import sys
@@ -16,6 +17,9 @@ import subprocess
import tempfile
import warnings
from . import py33_winapi as _winapi
from .py33_exceptions import wrap_error, BlockingIOError, InterruptedError
__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle']
@@ -64,7 +68,7 @@ else:
try:
csock.setblocking(False)
try:
csock.connect((addr, port))
wrap_error(csock.connect, (addr, port))
except (BlockingIOError, InterruptedError):
pass
csock.setblocking(True)
@@ -80,7 +84,7 @@ else:
# Replacement for os.pipe() using handles instead of fds
def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
def pipe(duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
"""Like os.pipe() but with overlapped support and using handles not fds."""
address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' %
(os.getpid(), next(_mmap_counter)))
@@ -115,7 +119,12 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
flags_and_attribs, _winapi.NULL)
ov = _winapi.ConnectNamedPipe(h1, overlapped=True)
ov.GetOverlappedResult(True)
if hasattr(ov, 'GetOverlappedResult'):
# _winapi module of Python 3.3
ov.GetOverlappedResult(True)
else:
# _overlapped module
wrap_error(ov.getresult, True)
return h1, h2
except:
if h1 is not None:
@@ -128,7 +137,7 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
# Wrapper for a pipe handle
class PipeHandle:
class PipeHandle(object):
"""Wrapper for an overlapped pipe handle which is vaguely file-object like.
The IOCP event loop can use these instead of socket objects.
@@ -152,7 +161,7 @@ class PipeHandle:
raise ValueError("I/O operatioon on closed pipe")
return self._handle
def close(self, *, CloseHandle=_winapi.CloseHandle):
def close(self, CloseHandle=_winapi.CloseHandle):
if self._handle is not None:
CloseHandle(self._handle)
self._handle = None
@@ -200,8 +209,11 @@ class Popen(subprocess.Popen):
else:
stderr_wfd = stderr
try:
super().__init__(args, stdin=stdin_rfd, stdout=stdout_wfd,
stderr=stderr_wfd, **kwds)
super(Popen, self).__init__(args,
stdin=stdin_rfd,
stdout=stdout_wfd,
stderr=stderr_wfd,
**kwds)
except:
for h in (stdin_wh, stdout_rh, stderr_rh):
if h is not None: