From a4749501b494ecda86c2092b4d4b86cc5e4bf2ff Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 7 Jul 2015 21:17:17 +0200 Subject: [PATCH] Port asyncio to Python 2, trollius/ directory --- trollius/__init__.py | 11 +- trollius/base_events.py | 155 +++++++------ trollius/base_subprocess.py | 7 +- trollius/coroutines.py | 134 ++++++++++- trollius/events.py | 442 +++++++++++++++++++----------------- trollius/futures.py | 98 ++++++-- trollius/locks.py | 86 +++---- trollius/proactor_events.py | 20 +- trollius/protocols.py | 2 +- trollius/queues.py | 30 +-- trollius/selector_events.py | 126 +++++----- trollius/selectors.py | 72 +++--- trollius/sslproto.py | 20 +- trollius/streams.py | 74 +++--- trollius/subprocess.py | 55 +++-- trollius/tasks.py | 242 ++++++++++++-------- trollius/test_support.py | 17 +- trollius/test_utils.py | 235 +++++++++++++++---- trollius/transports.py | 15 +- trollius/unix_events.py | 148 +++++++----- trollius/windows_events.py | 93 ++++---- trollius/windows_utils.py | 26 ++- 22 files changed, 1321 insertions(+), 787 deletions(-) diff --git a/trollius/__init__.py b/trollius/__init__.py index 011466b..a1379fb 100644 --- a/trollius/__init__.py +++ b/trollius/__init__.py @@ -1,4 +1,4 @@ -"""The asyncio package, tracking PEP 3156.""" +"""The trollius package, tracking PEP 3156.""" import sys @@ -24,6 +24,7 @@ from .events import * from .futures import * from .locks import * from .protocols import * +from .py33_exceptions import * from .queues import * from .streams import * from .subprocess import * @@ -33,6 +34,7 @@ from .transports import * __all__ = (base_events.__all__ + coroutines.__all__ + events.__all__ + + py33_exceptions.__all__ + futures.__all__ + locks.__all__ + protocols.__all__ + @@ -48,3 +50,10 @@ if sys.platform == 'win32': # pragma: no cover else: from .unix_events import * # pragma: no cover __all__ += unix_events.__all__ + +try: + from .py3_ssl import * + __all__ += py3_ssl.__all__ +except ImportError: + # SSL support is optionnal + pass diff --git a/trollius/base_events.py b/trollius/base_events.py index 5a536a2..c8541f1 100644 --- a/trollius/base_events.py +++ b/trollius/base_events.py @@ -15,25 +15,34 @@ to modify the meaning of the API call itself. import collections -import concurrent.futures import heapq import inspect import logging import os import socket import subprocess -import threading -import time -import traceback import sys -import warnings +import traceback +try: + from collections import OrderedDict +except ImportError: + # Python 2.6: use ordereddict backport + from ordereddict import OrderedDict +try: + from threading import get_ident as _get_thread_ident +except ImportError: + # Python 2 + from threading import _get_ident as _get_thread_ident +from . import compat from . import coroutines from . import events from . import futures from . import tasks -from .coroutines import coroutine +from .coroutines import coroutine, From, Return +from .executor import get_default_executor from .log import logger +from .time_monotonic import time_monotonic, time_monotonic_resolution __all__ = ['BaseEventLoop'] @@ -171,10 +180,10 @@ class Server(events.AbstractServer): @coroutine def wait_closed(self): if self.sockets is None or self._waiters is None: - return + raise Return() waiter = futures.Future(loop=self._loop) self._waiters.append(waiter) - yield from waiter + yield From(waiter) class BaseEventLoop(events.AbstractEventLoop): @@ -191,8 +200,7 @@ class BaseEventLoop(events.AbstractEventLoop): self._thread_id = None self._clock_resolution = time.get_clock_info('monotonic').resolution self._exception_handler = None - self.set_debug((not sys.flags.ignore_environment - and bool(os.environ.get('PYTHONASYNCIODEBUG')))) + self.set_debug(bool(os.environ.get('TROLLIUSDEBUG'))) # In debug mode, if the execution of a callback or a step of a task # exceed this duration in seconds, the slow callback/task is logged. self.slow_callback_duration = 0.1 @@ -237,13 +245,13 @@ class BaseEventLoop(events.AbstractEventLoop): """Return a task factory, or None if the default one is in use.""" return self._task_factory - def _make_socket_transport(self, sock, protocol, waiter=None, *, + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, server=None): """Create socket transport.""" raise NotImplementedError def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, + server_side=False, server_hostname=None, extra=None, server=None): """Create SSL transport.""" raise NotImplementedError @@ -506,7 +514,7 @@ class BaseEventLoop(events.AbstractEventLoop): if executor is None: executor = self._default_executor if executor is None: - executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + executor = get_default_executor() self._default_executor = executor return futures.wrap_future(executor.submit(func, *args), loop=self) @@ -538,7 +546,7 @@ class BaseEventLoop(events.AbstractEventLoop): logger.debug(msg) return addrinfo - def getaddrinfo(self, host, port, *, + def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): if self._debug: return self.run_in_executor(None, self._getaddrinfo_debug, @@ -551,7 +559,7 @@ class BaseEventLoop(events.AbstractEventLoop): return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) @coroutine - def create_connection(self, protocol_factory, host=None, port=None, *, + def create_connection(self, protocol_factory, host=None, port=None, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None): """Connect to a TCP server. @@ -601,15 +609,15 @@ class BaseEventLoop(events.AbstractEventLoop): else: f2 = None - yield from tasks.wait(fs, loop=self) + yield From(tasks.wait(fs, loop=self)) infos = f1.result() if not infos: - raise OSError('getaddrinfo() returned empty list') + raise socket.error('getaddrinfo() returned empty list') if f2 is not None: laddr_infos = f2.result() if not laddr_infos: - raise OSError('getaddrinfo() returned empty list') + raise socket.error('getaddrinfo() returned empty list') exceptions = [] for family, type, proto, cname, address in infos: @@ -621,11 +629,11 @@ class BaseEventLoop(events.AbstractEventLoop): try: sock.bind(laddr) break - except OSError as exc: - exc = OSError( + except socket.error as exc: + exc = socket.error( exc.errno, 'error while ' 'attempting to bind on address ' - '{!r}: {}'.format( + '{0!r}: {1}'.format( laddr, exc.strerror.lower())) exceptions.append(exc) else: @@ -634,8 +642,8 @@ class BaseEventLoop(events.AbstractEventLoop): continue if self._debug: logger.debug("connect %r to %r", sock, address) - yield from self.sock_connect(sock, address) - except OSError as exc: + yield From(self.sock_connect(sock, address)) + except socket.error as exc: if sock is not None: sock.close() exceptions.append(exc) @@ -655,7 +663,7 @@ class BaseEventLoop(events.AbstractEventLoop): raise exceptions[0] # Raise a combined exception so the user can see all # the various error messages. - raise OSError('Multiple exceptions: {}'.format( + raise socket.error('Multiple exceptions: {0}'.format( ', '.join(str(exc) for exc in exceptions))) elif sock is None: @@ -664,15 +672,15 @@ class BaseEventLoop(events.AbstractEventLoop): sock.setblocking(False) - transport, protocol = yield from self._create_connection_transport( - sock, protocol_factory, ssl, server_hostname) + transport, protocol = yield From(self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname)) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket sock = transport.get_extra_info('socket') logger.debug("%r connected to %s:%r: (%r, %r)", sock, host, port, transport, protocol) - return transport, protocol + raise Return(transport, protocol) @coroutine def _create_connection_transport(self, sock, protocol_factory, ssl, @@ -688,12 +696,12 @@ class BaseEventLoop(events.AbstractEventLoop): transport = self._make_socket_transport(sock, protocol, waiter) try: - yield from waiter + yield From(waiter) except: transport.close() raise - return transport, protocol + raise Return(transport, protocol) @coroutine def create_datagram_endpoint(self, protocol_factory, @@ -706,17 +714,17 @@ class BaseEventLoop(events.AbstractEventLoop): addr_pairs_info = (((family, proto), (None, None)),) else: # join address by (family, protocol) - addr_infos = collections.OrderedDict() + addr_infos = OrderedDict() for idx, addr in ((0, local_addr), (1, remote_addr)): if addr is not None: assert isinstance(addr, tuple) and len(addr) == 2, ( '2-tuple is expected') - infos = yield from self.getaddrinfo( + infos = yield From(self.getaddrinfo( *addr, family=family, type=socket.SOCK_DGRAM, - proto=proto, flags=flags) + proto=proto, flags=flags)) if not infos: - raise OSError('getaddrinfo() returned empty list') + raise socket.error('getaddrinfo() returned empty list') for fam, _, pro, _, address in infos: key = (fam, pro) @@ -748,9 +756,9 @@ class BaseEventLoop(events.AbstractEventLoop): if local_addr: sock.bind(local_address) if remote_addr: - yield from self.sock_connect(sock, remote_address) + yield From(self.sock_connect(sock, remote_address)) r_addr = remote_address - except OSError as exc: + except socket.error as exc: if sock is not None: sock.close() exceptions.append(exc) @@ -778,16 +786,15 @@ class BaseEventLoop(events.AbstractEventLoop): remote_addr, transport, protocol) try: - yield from waiter + yield From(waiter) except: transport.close() raise - return transport, protocol + raise Return(transport, protocol) @coroutine def create_server(self, protocol_factory, host=None, port=None, - *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, @@ -814,11 +821,11 @@ class BaseEventLoop(events.AbstractEventLoop): if host == '': host = None - infos = yield from self.getaddrinfo( + infos = yield From(self.getaddrinfo( host, port, family=family, - type=socket.SOCK_STREAM, proto=0, flags=flags) + type=socket.SOCK_STREAM, proto=0, flags=flags)) if not infos: - raise OSError('getaddrinfo() returned empty list') + raise socket.error('getaddrinfo() returned empty list') completed = False try: @@ -846,10 +853,11 @@ class BaseEventLoop(events.AbstractEventLoop): True) try: sock.bind(sa) - except OSError as err: - raise OSError(err.errno, 'error while attempting ' - 'to bind on address %r: %s' - % (sa, err.strerror.lower())) + except socket.error as err: + raise socket.error(err.errno, + 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) completed = True finally: if not completed: @@ -867,7 +875,7 @@ class BaseEventLoop(events.AbstractEventLoop): self._start_serving(protocol_factory, sock, ssl, server) if self._debug: logger.info("%r is serving", server) - return server + raise Return(server) @coroutine def connect_read_pipe(self, protocol_factory, pipe): @@ -876,7 +884,7 @@ class BaseEventLoop(events.AbstractEventLoop): transport = self._make_read_pipe_transport(pipe, protocol, waiter) try: - yield from waiter + yield From(waiter) except: transport.close() raise @@ -884,7 +892,7 @@ class BaseEventLoop(events.AbstractEventLoop): if self._debug: logger.debug('Read pipe %r connected: (%r, %r)', pipe.fileno(), transport, protocol) - return transport, protocol + raise Return(transport, protocol) @coroutine def connect_write_pipe(self, protocol_factory, pipe): @@ -893,7 +901,7 @@ class BaseEventLoop(events.AbstractEventLoop): transport = self._make_write_pipe_transport(pipe, protocol, waiter) try: - yield from waiter + yield From(waiter) except: transport.close() raise @@ -901,7 +909,7 @@ class BaseEventLoop(events.AbstractEventLoop): if self._debug: logger.debug('Write pipe %r connected: (%r, %r)', pipe.fileno(), transport, protocol) - return transport, protocol + raise Return(transport, protocol) def _log_subprocess(self, msg, stdin, stdout, stderr): info = [msg] @@ -917,11 +925,11 @@ class BaseEventLoop(events.AbstractEventLoop): logger.debug(' '.join(info)) @coroutine - def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + def subprocess_shell(self, protocol_factory, cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=False, shell=True, bufsize=0, **kwargs): - if not isinstance(cmd, (bytes, str)): + if not isinstance(cmd, compat.string_types): raise ValueError("cmd must be a string") if universal_newlines: raise ValueError("universal_newlines must be False") @@ -935,17 +943,20 @@ class BaseEventLoop(events.AbstractEventLoop): # (password) and may be too long debug_log = 'run shell command %r' % cmd self._log_subprocess(debug_log, stdin, stdout, stderr) - transport = yield from self._make_subprocess_transport( - protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) + transport = yield From(self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs)) if self._debug: logger.info('%s: %r' % (debug_log, transport)) - return transport, protocol + raise Return(transport, protocol) @coroutine - def subprocess_exec(self, protocol_factory, program, *args, - stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, universal_newlines=False, - shell=False, bufsize=0, **kwargs): + def subprocess_exec(self, protocol_factory, program, *args, **kwargs): + stdin = kwargs.pop('stdin', subprocess.PIPE) + stdout = kwargs.pop('stdout', subprocess.PIPE) + stderr = kwargs.pop('stderr', subprocess.PIPE) + universal_newlines = kwargs.pop('universal_newlines', False) + shell = kwargs.pop('shell', False) + bufsize = kwargs.pop('bufsize', 0) if universal_newlines: raise ValueError("universal_newlines must be False") if shell: @@ -954,7 +965,7 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError("bufsize must be 0") popen_args = (program,) + args for arg in popen_args: - if not isinstance(arg, (str, bytes)): + if not isinstance(arg, compat.string_types ): raise TypeError("program arguments must be " "a bytes or text string, not %s" % type(arg).__name__) @@ -964,12 +975,12 @@ class BaseEventLoop(events.AbstractEventLoop): # (password) and may be too long debug_log = 'execute program %r' % program self._log_subprocess(debug_log, stdin, stdout, stderr) - transport = yield from self._make_subprocess_transport( + transport = yield From(self._make_subprocess_transport( protocol, popen_args, False, stdin, stdout, stderr, - bufsize, **kwargs) + bufsize, **kwargs)) if self._debug: logger.info('%s: %r' % (debug_log, transport)) - return transport, protocol + raise Return(transport, protocol) def set_exception_handler(self, handler): """Set handler as the new event loop exception handler. @@ -985,7 +996,7 @@ class BaseEventLoop(events.AbstractEventLoop): """ if handler is not None and not callable(handler): raise TypeError('A callable object or None is expected, ' - 'got {!r}'.format(handler)) + 'got {0!r}'.format(handler)) self._exception_handler = handler def default_exception_handler(self, context): @@ -1004,7 +1015,15 @@ class BaseEventLoop(events.AbstractEventLoop): exception = context.get('exception') if exception is not None: - exc_info = (type(exception), exception, exception.__traceback__) + if hasattr(exception, '__traceback__'): + # Python 3 + tb = exception.__traceback__ + else: + # call_exception_handler() is usually called indirectly + # from an except block. If it's not the case, the traceback + # is undefined... + tb = sys.exc_info()[2] + exc_info = (type(exception), exception, tb) else: exc_info = False @@ -1015,7 +1034,7 @@ class BaseEventLoop(events.AbstractEventLoop): log_lines = [message] for key in sorted(context): - if key in {'message', 'exception'}: + if key in ('message', 'exception'): continue value = context[key] if key == 'source_traceback': @@ -1028,7 +1047,7 @@ class BaseEventLoop(events.AbstractEventLoop): value += tb.rstrip() else: value = repr(value) - log_lines.append('{}: {}'.format(key, value)) + log_lines.append('{0}: {1}'.format(key, value)) logger.error('\n'.join(log_lines), exc_info=exc_info) @@ -1108,7 +1127,7 @@ class BaseEventLoop(events.AbstractEventLoop): sched_count = len(self._scheduled) if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and - self._timer_cancelled_count / sched_count > + float(self._timer_cancelled_count) / sched_count > _MIN_CANCELLED_TIMER_HANDLES_FRACTION): # Remove delayed calls that were cancelled if their number # is too high diff --git a/trollius/base_subprocess.py b/trollius/base_subprocess.py index c1477b8..d3a6465 100644 --- a/trollius/base_subprocess.py +++ b/trollius/base_subprocess.py @@ -6,7 +6,7 @@ import warnings from . import futures from . import protocols from . import transports -from .coroutines import coroutine +from .coroutines import coroutine, From from .log import logger @@ -15,7 +15,7 @@ class BaseSubprocessTransport(transports.SubprocessTransport): def __init__(self, loop, protocol, args, shell, stdin, stdout, stderr, bufsize, waiter=None, extra=None, **kwargs): - super().__init__(extra) + super(BaseSubprocessTransport, self).__init__(extra) self._closed = False self._protocol = protocol self._loop = loop @@ -221,7 +221,8 @@ class BaseSubprocessTransport(transports.SubprocessTransport): waiter = futures.Future(loop=self._loop) self._exit_waiters.append(waiter) - return (yield from waiter) + returncode = yield From(waiter) + return returncode def _try_finish(self): assert not self._finished diff --git a/trollius/coroutines.py b/trollius/coroutines.py index 15475f2..b12ca3e 100644 --- a/trollius/coroutines.py +++ b/trollius/coroutines.py @@ -9,6 +9,7 @@ import sys import traceback import types +from . import compat from . import events from . import futures from .log import logger @@ -18,7 +19,7 @@ _PY35 = sys.version_info >= (3, 5) # Opcode of "yield from" instruction -_YIELD_FROM = opcode.opmap['YIELD_FROM'] +_YIELD_FROM = opcode.opmap.get('YIELD_FROM', None) # If you set _DEBUG to true, @coroutine will wrap the resulting # generator objects in a CoroWrapper instance (defined below). That @@ -29,8 +30,7 @@ _YIELD_FROM = opcode.opmap['YIELD_FROM'] # before you define your coroutines. A downside of using this feature # is that tracebacks show entries for the CoroWrapper.__next__ method # when _DEBUG is true. -_DEBUG = (not sys.flags.ignore_environment - and bool(os.environ.get('PYTHONASYNCIODEBUG'))) +_DEBUG = bool(os.environ.get('TROLLIUSDEBUG')) try: @@ -74,6 +74,53 @@ _YIELD_FROM_BUG = has_yield_from_bug() del has_yield_from_bug +if compat.PY33: + # Don't use the Return class on Python 3.3 and later to support asyncio + # coroutines (to avoid the warning emited in Return destructor). + # + # The problem is that Return inherits from StopIteration. "yield from + # trollius_coroutine". Task._step() does not receive the Return exception, + # because "yield from" handles it internally. So it's not possible to set + # the raised attribute to True to avoid the warning in Return destructor. + def Return(*args): + if not args: + value = None + elif len(args) == 1: + value = args[0] + else: + value = args + return StopIteration(value) +else: + class Return(StopIteration): + def __init__(self, *args): + StopIteration.__init__(self) + if not args: + self.value = None + elif len(args) == 1: + self.value = args[0] + else: + self.value = args + self.raised = False + if _DEBUG: + frame = sys._getframe(1) + self._source_traceback = traceback.extract_stack(frame) + # explicitly clear the reference to avoid reference cycles + frame = None + else: + self._source_traceback = None + + def __del__(self): + if self.raised: + return + + fmt = 'Return(%r) used without raise' + if self._source_traceback: + fmt += '\nReturn created at (most recent call last):\n' + tb = ''.join(traceback.format_list(self._source_traceback)) + fmt += tb.rstrip() + logger.error(fmt, self.value) + + def debug_wrapper(gen): # This function is called from 'sys.set_coroutine_wrapper'. # We only wrap here coroutines defined via 'async def' syntax. @@ -104,7 +151,8 @@ class CoroWrapper: return self def __next__(self): - return self.gen.send(None) + return next(self.gen) + next = __next__ if _YIELD_FROM_BUG: # For for CPython issue #21209: using "yield from" and a custom @@ -180,6 +228,56 @@ class CoroWrapper: msg += tb.rstrip() logger.error(msg) +if not compat.PY34: + # Backport functools.update_wrapper() from Python 3.4: + # - Python 2.7 fails if assigned attributes don't exist + # - Python 2.7 and 3.1 don't set the __wrapped__ attribute + # - Python 3.2 and 3.3 set __wrapped__ before updating __dict__ + def _update_wrapper(wrapper, + wrapped, + assigned = functools.WRAPPER_ASSIGNMENTS, + updated = functools.WRAPPER_UPDATES): + """Update a wrapper function to look like the wrapped function + + wrapper is the function to be updated + wrapped is the original function + assigned is a tuple naming the attributes assigned directly + from the wrapped function to the wrapper function (defaults to + functools.WRAPPER_ASSIGNMENTS) + updated is a tuple naming the attributes of the wrapper that + are updated with the corresponding attribute from the wrapped + function (defaults to functools.WRAPPER_UPDATES) + """ + for attr in assigned: + try: + value = getattr(wrapped, attr) + except AttributeError: + pass + else: + setattr(wrapper, attr, value) + for attr in updated: + getattr(wrapper, attr).update(getattr(wrapped, attr, {})) + # Issue #17482: set __wrapped__ last so we don't inadvertently copy it + # from the wrapped function when updating __dict__ + wrapper.__wrapped__ = wrapped + # Return the wrapper so this can be used as a decorator via partial() + return wrapper + + def _wraps(wrapped, + assigned = functools.WRAPPER_ASSIGNMENTS, + updated = functools.WRAPPER_UPDATES): + """Decorator factory to apply update_wrapper() to a wrapper function + + Returns a decorator that invokes update_wrapper() with the decorated + function as the wrapper argument and the arguments to wraps() as the + remaining arguments. Default arguments are as for update_wrapper(). + This is a convenience function to simplify applying partial() to + update_wrapper(). + """ + return functools.partial(_update_wrapper, wrapped=wrapped, + assigned=assigned, updated=updated) +else: + _wraps = functools.wraps def coroutine(func): """Decorator to mark coroutines. @@ -197,7 +295,7 @@ def coroutine(func): if inspect.isgeneratorfunction(func): coro = func else: - @functools.wraps(func) + @_wraps(func) def coro(*args, **kw): res = func(*args, **kw) if isinstance(res, futures.Future) or inspect.isgenerator(res): @@ -220,7 +318,7 @@ def coroutine(func): else: wrapper = _types_coroutine(coro) else: - @functools.wraps(func) + @_wraps(func) def wrapper(*args, **kwds): w = CoroWrapper(coro(*args, **kwds), func=func) if w._source_traceback: @@ -246,7 +344,13 @@ def iscoroutinefunction(func): _COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) if _CoroutineABC is not None: _COROUTINE_TYPES += (_CoroutineABC,) - +if events.asyncio is not None: + # Accept also asyncio CoroWrapper for interoperability + if hasattr(events.asyncio, 'coroutines'): + _COROUTINE_TYPES += (events.asyncio.coroutines.CoroWrapper,) + else: + # old Tulip/Python versions + _COROUTINE_TYPES += (events.asyncio.tasks.CoroWrapper,) def iscoroutine(obj): """Return True if obj is a coroutine object.""" @@ -299,3 +403,19 @@ def _format_coroutine(coro): % (coro_name, filename, lineno)) return coro_repr + + +class FromWrapper(object): + __slots__ = ('obj',) + + def __init__(self, obj): + if isinstance(obj, FromWrapper): + obj = obj.obj + assert not isinstance(obj, FromWrapper) + self.obj = obj + +def From(obj): + if not _DEBUG: + return obj + else: + return FromWrapper(obj) diff --git a/trollius/events.py b/trollius/events.py index 496075b..3aa5b69 100644 --- a/trollius/events.py +++ b/trollius/events.py @@ -10,12 +10,23 @@ __all__ = ['AbstractEventLoopPolicy', import functools import inspect -import reprlib import socket import subprocess import sys import threading import traceback +try: + import reprlib # Python 3 +except ImportError: + import repr as reprlib # Python 2 + +from trollius import compat +try: + import asyncio +except (ImportError, SyntaxError): + # ignore SyntaxError for convenience: ignore SyntaxError caused by "yield + # from" if asyncio module is in the Python path + asyncio = None _PY34 = sys.version_info >= (3, 4) @@ -75,7 +86,7 @@ def _format_callback_source(func, args): return func_repr -class Handle: +class Handle(object): """Object returned by callback registration methods.""" __slots__ = ('_callback', '_args', '_cancelled', '_loop', @@ -145,14 +156,14 @@ class TimerHandle(Handle): def __init__(self, when, callback, args, loop): assert when is not None - super().__init__(callback, args, loop) + super(TimerHandle, self).__init__(callback, args, loop) if self._source_traceback: del self._source_traceback[-1] self._when = when self._scheduled = False def _repr_info(self): - info = super()._repr_info() + info = super(TimerHandle, self)._repr_info() pos = 2 if self._cancelled else 1 info.insert(pos, 'when=%s' % self._when) return info @@ -191,10 +202,10 @@ class TimerHandle(Handle): def cancel(self): if not self._cancelled: self._loop._timer_handle_cancelled(self) - super().cancel() + super(TimerHandle, self).cancel() -class AbstractServer: +class AbstractServer(object): """Abstract server returned by create_server().""" def close(self): @@ -206,298 +217,303 @@ class AbstractServer: return NotImplemented -class AbstractEventLoop: - """Abstract event loop.""" +if asyncio is not None: + # Reuse asyncio classes so asyncio.set_event_loop() and + # asyncio.set_event_loop_policy() accept Trollius event loop and trollius + # event loop policy + AbstractEventLoop = asyncio.AbstractEventLoop + AbstractEventLoopPolicy = asyncio.AbstractEventLoopPolicy +else: + class AbstractEventLoop(object): + """Abstract event loop.""" - # Running and stopping the event loop. + # Running and stopping the event loop. - def run_forever(self): - """Run the event loop until stop() is called.""" - raise NotImplementedError + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError - def run_until_complete(self, future): - """Run the event loop until a Future is done. + def run_until_complete(self, future): + """Run the event loop until a Future is done. - Return the Future's result, or raise its exception. - """ - raise NotImplementedError + Return the Future's result, or raise its exception. + """ + raise NotImplementedError - def stop(self): - """Stop the event loop as soon as reasonable. + def stop(self): + """Stop the event loop as soon as reasonable. - Exactly how soon that is may depend on the implementation, but - no more I/O callbacks should be scheduled. - """ - raise NotImplementedError + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError - def is_running(self): - """Return whether the event loop is currently running.""" - raise NotImplementedError + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError - def is_closed(self): - """Returns True if the event loop was closed.""" - raise NotImplementedError + def is_closed(self): + """Returns True if the event loop was closed.""" + raise NotImplementedError - def close(self): - """Close the loop. + def close(self): + """Close the loop. - The loop should not be running. + The loop should not be running. - This is idempotent and irreversible. + This is idempotent and irreversible. - No other methods should be called after this one. - """ - raise NotImplementedError + No other methods should be called after this one. + """ + raise NotImplementedError - # Methods scheduling callbacks. All these return Handles. + # Methods scheduling callbacks. All these return Handles. - def _timer_handle_cancelled(self, handle): - """Notification that a TimerHandle has been cancelled.""" - raise NotImplementedError + def _timer_handle_cancelled(self, handle): + """Notification that a TimerHandle has been cancelled.""" + raise NotImplementedError - def call_soon(self, callback, *args): - return self.call_later(0, callback, *args) + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) - def call_later(self, delay, callback, *args): - raise NotImplementedError + def call_later(self, delay, callback, *args): + raise NotImplementedError - def call_at(self, when, callback, *args): - raise NotImplementedError + def call_at(self, when, callback, *args): + raise NotImplementedError - def time(self): - raise NotImplementedError + def time(self): + raise NotImplementedError - # Method scheduling a coroutine object: create a task. + # Method scheduling a coroutine object: create a task. - def create_task(self, coro): - raise NotImplementedError + def create_task(self, coro): + raise NotImplementedError - # Methods for interacting with threads. + # Methods for interacting with threads. - def call_soon_threadsafe(self, callback, *args): - raise NotImplementedError + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError - def run_in_executor(self, executor, func, *args): - raise NotImplementedError + def run_in_executor(self, executor, func, *args): + raise NotImplementedError - def set_default_executor(self, executor): - raise NotImplementedError + def set_default_executor(self, executor): + raise NotImplementedError - # Network I/O methods returning Futures. + # Network I/O methods returning Futures. - def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): - raise NotImplementedError + def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): + raise NotImplementedError - def getnameinfo(self, sockaddr, flags=0): - raise NotImplementedError + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError - def create_connection(self, protocol_factory, host=None, port=None, *, - ssl=None, family=0, proto=0, flags=0, sock=None, - local_addr=None, server_hostname=None): - raise NotImplementedError + def create_connection(self, protocol_factory, host=None, port=None, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None, server_hostname=None): + raise NotImplementedError - def create_server(self, protocol_factory, host=None, port=None, *, - family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, - sock=None, backlog=100, ssl=None, reuse_address=None): - """A coroutine which creates a TCP server bound to host and port. + def create_server(self, protocol_factory, host=None, port=None, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """A coroutine which creates a TCP server bound to host and port. - The return value is a Server object which can be used to stop - the service. + The return value is a Server object which can be used to stop + the service. - If host is an empty string or None all interfaces are assumed - and a list of multiple sockets will be returned (most likely - one for IPv4 and another one for IPv6). + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). - family can be set to either AF_INET or AF_INET6 to force the - socket to use IPv4 or IPv6. If not set it will be determined - from host (defaults to AF_UNSPEC). + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). - flags is a bitmask for getaddrinfo(). + flags is a bitmask for getaddrinfo(). - sock can optionally be specified in order to use a preexisting - socket object. + sock can optionally be specified in order to use a preexisting + socket object. - backlog is the maximum number of queued connections passed to - listen() (defaults to 100). + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). - ssl can be set to an SSLContext to enable SSL over the - accepted connections. + ssl can be set to an SSLContext to enable SSL over the + accepted connections. - reuse_address tells the kernel to reuse a local socket in - TIME_WAIT state, without waiting for its natural timeout to - expire. If not specified will automatically be set to True on - UNIX. - """ - raise NotImplementedError + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError - def create_unix_connection(self, protocol_factory, path, *, - ssl=None, sock=None, - server_hostname=None): - raise NotImplementedError + def create_unix_connection(self, protocol_factory, path, + ssl=None, sock=None, + server_hostname=None): + raise NotImplementedError - def create_unix_server(self, protocol_factory, path, *, - sock=None, backlog=100, ssl=None): - """A coroutine which creates a UNIX Domain Socket server. + def create_unix_server(self, protocol_factory, path, + sock=None, backlog=100, ssl=None): + """A coroutine which creates a UNIX Domain Socket server. - The return value is a Server object, which can be used to stop - the service. + The return value is a Server object, which can be used to stop + the service. - path is a str, representing a file systsem path to bind the - server socket to. + path is a str, representing a file systsem path to bind the + server socket to. - sock can optionally be specified in order to use a preexisting - socket object. + sock can optionally be specified in order to use a preexisting + socket object. - backlog is the maximum number of queued connections passed to - listen() (defaults to 100). + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). - ssl can be set to an SSLContext to enable SSL over the - accepted connections. - """ - raise NotImplementedError + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + """ + raise NotImplementedError - def create_datagram_endpoint(self, protocol_factory, - local_addr=None, remote_addr=None, *, - family=0, proto=0, flags=0): - raise NotImplementedError + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, + family=0, proto=0, flags=0): + raise NotImplementedError - # Pipes and subprocesses. + # Pipes and subprocesses. - def connect_read_pipe(self, protocol_factory, pipe): - """Register read pipe in event loop. Set the pipe to non-blocking mode. + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in event loop. Set the pipe to non-blocking mode. - protocol_factory should instantiate object with Protocol interface. - pipe is a file-like object. - Return pair (transport, protocol), where transport supports the - ReadTransport interface.""" - # The reason to accept file-like object instead of just file descriptor - # is: we need to own pipe and close it at transport finishing - # Can got complicated errors if pass f.fileno(), - # close fd in pipe transport then close f and vise versa. - raise NotImplementedError + protocol_factory should instantiate object with Protocol interface. + pipe is a file-like object. + Return pair (transport, protocol), where transport supports the + ReadTransport interface.""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError - def connect_write_pipe(self, protocol_factory, pipe): - """Register write pipe in event loop. + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in event loop. - protocol_factory should instantiate object with BaseProtocol interface. - Pipe is file-like object already switched to nonblocking. - Return pair (transport, protocol), where transport support - WriteTransport interface.""" - # The reason to accept file-like object instead of just file descriptor - # is: we need to own pipe and close it at transport finishing - # Can got complicated errors if pass f.fileno(), - # close fd in pipe transport then close f and vise versa. - raise NotImplementedError + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport interface.""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError - def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - **kwargs): - raise NotImplementedError + def subprocess_shell(self, protocol_factory, cmd, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError - def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - **kwargs): - raise NotImplementedError + def subprocess_exec(self, protocol_factory, *args, **kwargs): + raise NotImplementedError - # Ready-based callback registration methods. - # The add_*() methods return None. - # The remove_*() methods return True if something was removed, - # False if there was nothing to delete. + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. - def add_reader(self, fd, callback, *args): - raise NotImplementedError + def add_reader(self, fd, callback, *args): + raise NotImplementedError - def remove_reader(self, fd): - raise NotImplementedError + def remove_reader(self, fd): + raise NotImplementedError - def add_writer(self, fd, callback, *args): - raise NotImplementedError + def add_writer(self, fd, callback, *args): + raise NotImplementedError - def remove_writer(self, fd): - raise NotImplementedError + def remove_writer(self, fd): + raise NotImplementedError - # Completion based I/O methods returning Futures. + # Completion based I/O methods returning Futures. - def sock_recv(self, sock, nbytes): - raise NotImplementedError + def sock_recv(self, sock, nbytes): + raise NotImplementedError - def sock_sendall(self, sock, data): - raise NotImplementedError + def sock_sendall(self, sock, data): + raise NotImplementedError - def sock_connect(self, sock, address): - raise NotImplementedError + def sock_connect(self, sock, address): + raise NotImplementedError - def sock_accept(self, sock): - raise NotImplementedError + def sock_accept(self, sock): + raise NotImplementedError - # Signal handling. + # Signal handling. - def add_signal_handler(self, sig, callback, *args): - raise NotImplementedError + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError - def remove_signal_handler(self, sig): - raise NotImplementedError + def remove_signal_handler(self, sig): + raise NotImplementedError - # Task factory. + # Task factory. - def set_task_factory(self, factory): - raise NotImplementedError + def set_task_factory(self, factory): + raise NotImplementedError - def get_task_factory(self): - raise NotImplementedError + def get_task_factory(self): + raise NotImplementedError - # Error handlers. + # Error handlers. - def set_exception_handler(self, handler): - raise NotImplementedError + def set_exception_handler(self, handler): + raise NotImplementedError - def default_exception_handler(self, context): - raise NotImplementedError + def default_exception_handler(self, context): + raise NotImplementedError - def call_exception_handler(self, context): - raise NotImplementedError + def call_exception_handler(self, context): + raise NotImplementedError - # Debug flag management. + # Debug flag management. - def get_debug(self): - raise NotImplementedError + def get_debug(self): + raise NotImplementedError - def set_debug(self, enabled): - raise NotImplementedError + def set_debug(self, enabled): + raise NotImplementedError -class AbstractEventLoopPolicy: - """Abstract policy for accessing the event loop.""" + class AbstractEventLoopPolicy(object): + """Abstract policy for accessing the event loop.""" - def get_event_loop(self): - """Get the event loop for the current context. + def get_event_loop(self): + """Get the event loop for the current context. - Returns an event loop object implementing the BaseEventLoop interface, - or raises an exception in case no event loop has been set for the - current context and the current policy does not specify to create one. + Returns an event loop object implementing the BaseEventLoop interface, + or raises an exception in case no event loop has been set for the + current context and the current policy does not specify to create one. - It should never return None.""" - raise NotImplementedError + It should never return None.""" + raise NotImplementedError - def set_event_loop(self, loop): - """Set the event loop for the current context to loop.""" - raise NotImplementedError + def set_event_loop(self, loop): + """Set the event loop for the current context to loop.""" + raise NotImplementedError - def new_event_loop(self): - """Create and return a new event loop object according to this - policy's rules. If there's need to set this loop as the event loop for - the current context, set_event_loop must be called explicitly.""" - raise NotImplementedError + def new_event_loop(self): + """Create and return a new event loop object according to this + policy's rules. If there's need to set this loop as the event loop for + the current context, set_event_loop must be called explicitly.""" + raise NotImplementedError - # Child processes handling (Unix only). + # Child processes handling (Unix only). - def get_child_watcher(self): - "Get the watcher for child processes." - raise NotImplementedError + def get_child_watcher(self): + "Get the watcher for child processes." + raise NotImplementedError - def set_child_watcher(self, watcher): - """Set the watcher for child processes.""" - raise NotImplementedError + def set_child_watcher(self, watcher): + """Set the watcher for child processes.""" + raise NotImplementedError class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy): diff --git a/trollius/futures.py b/trollius/futures.py index d06828a..80f6d11 100644 --- a/trollius/futures.py +++ b/trollius/futures.py @@ -5,13 +5,17 @@ __all__ = ['CancelledError', 'TimeoutError', 'Future', 'wrap_future', ] -import concurrent.futures._base import logging -import reprlib import sys import traceback +try: + import reprlib # Python 3 +except ImportError: + import repr as reprlib # Python 2 +from . import compat from . import events +from . import executor # States for Future. _PENDING = 'PENDING' @@ -21,9 +25,9 @@ _FINISHED = 'FINISHED' _PY34 = sys.version_info >= (3, 4) _PY35 = sys.version_info >= (3, 5) -Error = concurrent.futures._base.Error -CancelledError = concurrent.futures.CancelledError -TimeoutError = concurrent.futures.TimeoutError +Error = executor.Error +CancelledError = executor.CancelledError +TimeoutError = executor.TimeoutError STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging @@ -32,7 +36,7 @@ class InvalidStateError(Error): """The operation is not allowed in this state.""" -class _TracebackLogger: +class _TracebackLogger(object): """Helper to log a traceback upon destruction if not cleared. This solves a nasty problem with Futures and Tasks that have an @@ -112,7 +116,7 @@ class _TracebackLogger: self.loop.call_exception_handler({'message': msg}) -class Future: +class Future(object): """This class is *almost* compatible with concurrent.futures.Future. Differences: @@ -138,10 +142,14 @@ class Future: _blocking = False # proper use of future (yield vs yield from) + # Used by Python 2 to raise the exception with the original traceback + # in the exception() method in debug mode + _exception_tb = None + _log_traceback = False # Used for Python 3.4 and later _tb_logger = None # Used for Python 3.3 only - def __init__(self, *, loop=None): + def __init__(self, loop=None): """Initialize the future. The optional event_loop argument allows to explicitly set the event @@ -168,23 +176,23 @@ class Future: if size == 1: cb = format_cb(cb[0]) elif size == 2: - cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1])) + cb = '{0}, {1}'.format(format_cb(cb[0]), format_cb(cb[1])) elif size > 2: - cb = '{}, <{} more>, {}'.format(format_cb(cb[0]), - size-2, - format_cb(cb[-1])) + cb = '{0}, <{1} more>, {2}'.format(format_cb(cb[0]), + size-2, + format_cb(cb[-1])) return 'cb=[%s]' % cb def _repr_info(self): info = [self._state.lower()] if self._state == _FINISHED: if self._exception is not None: - info.append('exception={!r}'.format(self._exception)) + info.append('exception={0!r}'.format(self._exception)) else: # use reprlib to limit the length of the output, especially # for very long strings result = reprlib.repr(self._result) - info.append('result={}'.format(result)) + info.append('result={0}'.format(result)) if self._callbacks: info.append(self._format_callbacks()) if self._source_traceback: @@ -272,8 +280,13 @@ class Future: if self._tb_logger is not None: self._tb_logger.clear() self._tb_logger = None + exc_tb = self._exception_tb + self._exception_tb = None if self._exception is not None: - raise self._exception + if exc_tb is not None: + compat.reraise(type(self._exception), self._exception, exc_tb) + else: + raise self._exception return self._result def exception(self): @@ -292,6 +305,7 @@ class Future: if self._tb_logger is not None: self._tb_logger.clear() self._tb_logger = None + self._exception_tb = None return self._exception def add_done_callback(self, fn): @@ -334,31 +348,61 @@ class Future: InvalidStateError. """ if self._state != _PENDING: - raise InvalidStateError('{}: {!r}'.format(self._state, self)) + raise InvalidStateError('{0}: {1!r}'.format(self._state, self)) self._result = result self._state = _FINISHED self._schedule_callbacks() + def _get_exception_tb(self): + return self._exception_tb + def set_exception(self, exception): + self._set_exception_with_tb(exception, None) + + def _set_exception_with_tb(self, exception, exc_tb): """Mark the future done and set an exception. If the future is already done when this method is called, raises InvalidStateError. """ if self._state != _PENDING: - raise InvalidStateError('{}: {!r}'.format(self._state, self)) + raise InvalidStateError('{0}: {1!r}'.format(self._state, self)) if isinstance(exception, type): exception = exception() self._exception = exception + if exc_tb is not None: + self._exception_tb = exc_tb + exc_tb = None + elif self._loop.get_debug() and not compat.PY3: + self._exception_tb = sys.exc_info()[2] self._state = _FINISHED self._schedule_callbacks() if _PY34: self._log_traceback = True else: self._tb_logger = _TracebackLogger(self, exception) - # Arrange for the logger to be activated after all callbacks - # have had a chance to call result() or exception(). - self._loop.call_soon(self._tb_logger.activate) + if hasattr(exception, '__traceback__'): + # Python 3: exception contains a link to the traceback + + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + else: + if self._loop.get_debug(): + frame = sys._getframe(1) + tb = ['Traceback (most recent call last):\n'] + if self._exception_tb is not None: + tb += traceback.format_tb(self._exception_tb) + else: + tb += traceback.format_stack(frame) + tb += traceback.format_exception_only(type(exception), exception) + self._tb_logger.tb = tb + else: + self._tb_logger.tb = traceback.format_exception_only( + type(exception), + exception) + + self._tb_logger.exc = None # Truly internal methods. @@ -392,12 +436,18 @@ class Future: __await__ = __iter__ # make compatible with 'await' expression -def wrap_future(fut, *, loop=None): +if events.asyncio is not None: + # Accept also asyncio Future objects for interoperability + _FUTURE_CLASSES = (Future, events.asyncio.Future) +else: + _FUTURE_CLASSES = Future + +def wrap_future(fut, loop=None): """Wrap concurrent.futures.Future object.""" - if isinstance(fut, Future): + if isinstance(fut, _FUTURE_CLASSES): return fut - assert isinstance(fut, concurrent.futures.Future), \ - 'concurrent.futures.Future is expected, got {!r}'.format(fut) + assert isinstance(fut, executor.Future), \ + 'concurrent.futures.Future is expected, got {0!r}'.format(fut) if loop is None: loop = events.get_event_loop() new_future = Future(loop=loop) diff --git a/trollius/locks.py b/trollius/locks.py index b2e516b..ecbe3b3 100644 --- a/trollius/locks.py +++ b/trollius/locks.py @@ -7,7 +7,7 @@ import sys from . import events from . import futures -from .coroutines import coroutine +from .coroutines import coroutine, From, Return _PY35 = sys.version_info >= (3, 5) @@ -19,7 +19,7 @@ class _ContextManager: This enables the following idiom for acquiring and releasing a lock around a block: - with (yield from lock): + with (yield From(lock)): while failing loudly when accidentally using: @@ -43,7 +43,7 @@ class _ContextManager: self._lock = None # Crudely prevent reuse. -class _ContextManagerMixin: +class _ContextManagerMixin(object): def __enter__(self): raise RuntimeError( '"yield from" should be used as context manager expression') @@ -111,16 +111,16 @@ class Lock(_ContextManagerMixin): release() call resets the state to unlocked; first coroutine which is blocked in acquire() is being processed. - acquire() is a coroutine and should be called with 'yield from'. + acquire() is a coroutine and should be called with 'yield From'. - Locks also support the context management protocol. '(yield from lock)' + Locks also support the context management protocol. '(yield From(lock))' should be used as context manager expression. Usage: lock = Lock() ... - yield from lock + yield From(lock) try: ... finally: @@ -130,20 +130,20 @@ class Lock(_ContextManagerMixin): lock = Lock() ... - with (yield from lock): + with (yield From(lock)): ... Lock objects can be tested for locking state: if not lock.locked(): - yield from lock + yield From(lock) else: # lock is acquired ... """ - def __init__(self, *, loop=None): + def __init__(self, loop=None): self._waiters = collections.deque() self._locked = False if loop is not None: @@ -152,11 +152,11 @@ class Lock(_ContextManagerMixin): self._loop = events.get_event_loop() def __repr__(self): - res = super().__repr__() + res = super(Lock, self).__repr__() extra = 'locked' if self._locked else 'unlocked' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) def locked(self): """Return True if lock is acquired.""" @@ -171,14 +171,14 @@ class Lock(_ContextManagerMixin): """ if not self._waiters and not self._locked: self._locked = True - return True + raise Return(True) fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: - yield from fut + yield From(fut) self._locked = True - return True + raise Return(True) finally: self._waiters.remove(fut) @@ -204,7 +204,7 @@ class Lock(_ContextManagerMixin): raise RuntimeError('Lock is not acquired.') -class Event: +class Event(object): """Asynchronous equivalent to threading.Event. Class implementing event objects. An event manages a flag that can be set @@ -213,7 +213,7 @@ class Event: false. """ - def __init__(self, *, loop=None): + def __init__(self, loop=None): self._waiters = collections.deque() self._value = False if loop is not None: @@ -222,11 +222,11 @@ class Event: self._loop = events.get_event_loop() def __repr__(self): - res = super().__repr__() + res = super(Event, self).__repr__() extra = 'set' if self._value else 'unset' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) def is_set(self): """Return True if and only if the internal flag is true.""" @@ -259,13 +259,13 @@ class Event: set() to set the flag to true, then return True. """ if self._value: - return True + raise Return(True) fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: - yield from fut - return True + yield From(fut) + raise Return(True) finally: self._waiters.remove(fut) @@ -280,7 +280,7 @@ class Condition(_ContextManagerMixin): A new Lock object is created and used as the underlying lock. """ - def __init__(self, lock=None, *, loop=None): + def __init__(self, lock=None, loop=None): if loop is not None: self._loop = loop else: @@ -300,11 +300,11 @@ class Condition(_ContextManagerMixin): self._waiters = collections.deque() def __repr__(self): - res = super().__repr__() + res = super(Condition, self).__repr__() extra = 'locked' if self.locked() else 'unlocked' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) @coroutine def wait(self): @@ -326,13 +326,13 @@ class Condition(_ContextManagerMixin): fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: - yield from fut - return True + yield From(fut) + raise Return(True) finally: self._waiters.remove(fut) finally: - yield from self.acquire() + yield From(self.acquire()) @coroutine def wait_for(self, predicate): @@ -344,9 +344,9 @@ class Condition(_ContextManagerMixin): """ result = predicate() while not result: - yield from self.wait() + yield From(self.wait()) result = predicate() - return result + raise Return(result) def notify(self, n=1): """By default, wake up one coroutine waiting on this condition, if any. @@ -396,7 +396,7 @@ class Semaphore(_ContextManagerMixin): ValueError is raised. """ - def __init__(self, value=1, *, loop=None): + def __init__(self, value=1, loop=None): if value < 0: raise ValueError("Semaphore initial value must be >= 0") self._value = value @@ -407,12 +407,12 @@ class Semaphore(_ContextManagerMixin): self._loop = events.get_event_loop() def __repr__(self): - res = super().__repr__() - extra = 'locked' if self.locked() else 'unlocked,value:{}'.format( + res = super(Semaphore, self).__repr__() + extra = 'locked' if self.locked() else 'unlocked,value:{0}'.format( self._value) if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) def locked(self): """Returns True if semaphore can not be acquired immediately.""" @@ -430,14 +430,14 @@ class Semaphore(_ContextManagerMixin): """ if not self._waiters and self._value > 0: self._value -= 1 - return True + raise Return(True) fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: - yield from fut + yield From(fut) self._value -= 1 - return True + raise Return(True) finally: self._waiters.remove(fut) @@ -460,11 +460,11 @@ class BoundedSemaphore(Semaphore): above the initial value. """ - def __init__(self, value=1, *, loop=None): + def __init__(self, value=1, loop=None): self._bound_value = value - super().__init__(value, loop=loop) + super(BoundedSemaphore, self).__init__(value, loop=loop) def release(self): if self._value >= self._bound_value: raise ValueError('BoundedSemaphore released too many times') - super().release() + super(BoundedSemaphore, self).release() diff --git a/trollius/proactor_events.py b/trollius/proactor_events.py index 9c2b8f1..49d8bc3 100644 --- a/trollius/proactor_events.py +++ b/trollius/proactor_events.py @@ -16,6 +16,9 @@ from . import futures from . import sslproto from . import transports from .log import logger +from .compat import flatten_bytes +from .py33_exceptions import (BrokenPipeError, + ConnectionAbortedError, ConnectionResetError) class _ProactorBasePipeTransport(transports._FlowControlMixin, @@ -24,7 +27,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): - super().__init__(extra, loop) + super(_ProactorBasePipeTransport, self).__init__(extra, loop) self._set_extra(sock) self._sock = sock self._protocol = protocol @@ -143,7 +146,8 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): - super().__init__(loop, sock, protocol, waiter, extra, server) + super(_ProactorReadPipeTransport, self).__init__(loop, sock, protocol, + waiter, extra, server) self._paused = False self._loop.call_soon(self._loop_reading) @@ -220,9 +224,7 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, """Transport for write pipes.""" def write(self, data): - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) + data = flatten_bytes(data) if self._eof_written: raise RuntimeError('write_eof() already called') @@ -301,7 +303,7 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport): def __init__(self, *args, **kw): - super().__init__(*args, **kw) + super(_ProactorWritePipeTransport, self).__init__(*args, **kw) self._read_fut = self._loop._proactor.recv(self._sock, 16) self._read_fut.add_done_callback(self._pipe_closed) @@ -368,7 +370,7 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport, class BaseProactorEventLoop(base_events.BaseEventLoop): def __init__(self, proactor): - super().__init__() + super(BaseProactorEventLoop, self).__init__() logger.debug('Using proactor: %s', proactor.__class__.__name__) self._proactor = proactor self._selector = proactor # convenient alias @@ -383,7 +385,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): extra, server) def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, + server_side=False, server_hostname=None, extra=None, server=None): if not sslproto._is_sslproto_available(): raise NotImplementedError("Proactor event loop requires Python 3.5" @@ -427,7 +429,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): self._selector = None # Close the event loop - super().close() + super(BaseProactorEventLoop, self).close() def sock_recv(self, sock, n): return self._proactor.recv(sock, n) diff --git a/trollius/protocols.py b/trollius/protocols.py index 80fcac9..2c18287 100644 --- a/trollius/protocols.py +++ b/trollius/protocols.py @@ -4,7 +4,7 @@ __all__ = ['BaseProtocol', 'Protocol', 'DatagramProtocol', 'SubprocessProtocol'] -class BaseProtocol: +class BaseProtocol(object): """Common base class for protocol interfaces. Usually user implements protocols that derived from BaseProtocol diff --git a/trollius/queues.py b/trollius/queues.py index ed11662..d856305 100644 --- a/trollius/queues.py +++ b/trollius/queues.py @@ -9,6 +9,7 @@ import heapq from . import events from . import futures from . import locks +from .coroutines import From, Return from .tasks import coroutine @@ -26,7 +27,7 @@ class QueueFull(Exception): pass -class Queue: +class Queue(object): """A queue, useful for coordinating producer and consumer coroutines. If maxsize is less than or equal to zero, the queue size is infinite. If it @@ -38,7 +39,7 @@ class Queue: interrupted between calling qsize() and doing an operation on the Queue. """ - def __init__(self, maxsize=0, *, loop=None): + def __init__(self, maxsize=0, loop=None): if loop is None: self._loop = events.get_event_loop() else: @@ -73,22 +74,22 @@ class Queue: self._finished.clear() def __repr__(self): - return '<{} at {:#x} {}>'.format( + return '<{0} at {1:#x} {2}>'.format( type(self).__name__, id(self), self._format()) def __str__(self): - return '<{} {}>'.format(type(self).__name__, self._format()) + return '<{0} {1}>'.format(type(self).__name__, self._format()) def _format(self): - result = 'maxsize={!r}'.format(self._maxsize) + result = 'maxsize={0!r}'.format(self._maxsize) if getattr(self, '_queue', None): - result += ' _queue={!r}'.format(list(self._queue)) + result += ' _queue={0!r}'.format(list(self._queue)) if self._getters: - result += ' _getters[{}]'.format(len(self._getters)) + result += ' _getters[{0}]'.format(len(self._getters)) if self._putters: - result += ' _putters[{}]'.format(len(self._putters)) + result += ' _putters[{0}]'.format(len(self._putters)) if self._unfinished_tasks: - result += ' tasks={}'.format(self._unfinished_tasks) + result += ' tasks={0}'.format(self._unfinished_tasks) return result def _consume_done_getters(self): @@ -149,7 +150,7 @@ class Queue: waiter = futures.Future(loop=self._loop) self._putters.append((item, waiter)) - yield from waiter + yield From(waiter) else: self.__put_internal(item) @@ -195,15 +196,16 @@ class Queue: # ChannelTest.test_wait. self._loop.call_soon(putter._set_result_unless_cancelled, None) - return self._get() + raise Return(self._get()) elif self.qsize(): - return self._get() + raise Return(self._get()) else: waiter = futures.Future(loop=self._loop) self._getters.append(waiter) - return (yield from waiter) + result = yield From(waiter) + raise Return(result) def get_nowait(self): """Remove and return an item from the queue. @@ -257,7 +259,7 @@ class Queue: When the count of unfinished tasks drops to zero, join() unblocks. """ if self._unfinished_tasks > 0: - yield from self._finished.wait() + yield From(self._finished.wait()) class PriorityQueue(Queue): diff --git a/trollius/selector_events.py b/trollius/selector_events.py index 7c5b9b5..ec14974 100644 --- a/trollius/selector_events.py +++ b/trollius/selector_events.py @@ -14,6 +14,8 @@ import sys import warnings try: import ssl + from .py3_ssl import (wrap_ssl_error, SSLContext, SSLWantReadError, + SSLWantWriteError) except ImportError: # pragma: no cover ssl = None @@ -24,8 +26,26 @@ from . import futures from . import selectors from . import transports from . import sslproto +from .compat import flatten_bytes from .coroutines import coroutine from .log import logger +from .py33_exceptions import (wrap_error, + BlockingIOError, InterruptedError, ConnectionAbortedError, BrokenPipeError, + ConnectionResetError) + +# On Mac OS 10.6 with Python 2.6.1 or OpenIndiana 148 with Python 2.6.4, +# _SelectorSslTransport._read_ready() hangs if the socket has no data. +# Example: test_events.test_create_server_ssl() +_SSL_REQUIRES_SELECT = (sys.version_info < (2, 6, 6)) +if _SSL_REQUIRES_SELECT: + import select + + +def _get_socket_error(sock, address): + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise OSError(err, 'Connect call failed %s' % (address,)) def _test_selector_event(selector, fd, event): @@ -46,7 +66,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): """ def __init__(self, selector=None): - super().__init__() + super(BaseSelectorEventLoop, self).__init__() if selector is None: selector = selectors.DefaultSelector() @@ -54,13 +74,13 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): self._selector = selector self._make_self_pipe() - def _make_socket_transport(self, sock, protocol, waiter=None, *, + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, server=None): return _SelectorSocketTransport(self, sock, protocol, waiter, extra, server) def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, + server_side=False, server_hostname=None, extra=None, server=None): if not sslproto._is_sslproto_available(): return self._make_legacy_ssl_transport( @@ -75,7 +95,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): return ssl_protocol._app_transport def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext, - waiter, *, + waiter, server_side=False, server_hostname=None, extra=None, server=None): # Use the legacy API: SSL_write, SSL_read, etc. The legacy API is used @@ -95,7 +115,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): if self.is_closed(): return self._close_self_pipe() - super().close() + super(BaseSelectorEventLoop, self).close() if self._selector is not None: self._selector.close() self._selector = None @@ -125,7 +145,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): def _read_from_self(self): while True: try: - data = self._ssock.recv(4096) + data = wrap_error(self._ssock.recv, 4096) if not data: break self._process_self_data(data) @@ -143,7 +163,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): csock = self._csock if csock is not None: try: - csock.send(b'\0') + wrap_error(csock.send, b'\0') except OSError: if self._debug: logger.debug("Fail to write a null byte into the " @@ -158,14 +178,14 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): def _accept_connection(self, protocol_factory, sock, sslcontext=None, server=None): try: - conn, addr = sock.accept() + conn, addr = wrap_error(sock.accept) if self._debug: logger.debug("%r got a new connection from %r: %r", server, addr, conn) conn.setblocking(False) except (BlockingIOError, InterruptedError, ConnectionAbortedError): pass # False alarm. - except OSError as exc: + except socket.error as exc: # There's nowhere to send the error, so just log it. if exc.errno in (errno.EMFILE, errno.ENFILE, errno.ENOBUFS, errno.ENOMEM): @@ -331,7 +351,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): if fut.cancelled(): return try: - data = sock.recv(n) + data = wrap_error(sock.recv, n) except (BlockingIOError, InterruptedError): self.add_reader(fd, self._sock_recv, fut, True, sock, n) except Exception as exc: @@ -368,7 +388,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): return try: - n = sock.send(data) + n = wrap_error(sock.send, data) except (BlockingIOError, InterruptedError): n = 0 except Exception as exc: @@ -408,7 +428,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): def _sock_connect(self, fut, sock, address): fd = sock.fileno() try: - sock.connect(address) + wrap_error(sock.connect, address) except (BlockingIOError, InterruptedError): # Issue #23618: When the C function connect() fails with EINTR, the # connection runs in background. We have to wait until the socket @@ -430,10 +450,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): return try: - err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err != 0: - # Jump to any except clause below. - raise OSError(err, 'Connect call failed %s' % (address,)) + wrap_error(_get_socket_error, sock, address) except (BlockingIOError, InterruptedError): # socket is still registered, the callback will be retried later pass @@ -465,7 +482,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): if fut.cancelled(): return try: - conn, address = sock.accept() + conn, address = wrap_error(sock.accept) conn.setblocking(False) except (BlockingIOError, InterruptedError): self.add_reader(fd, self._sock_accept, fut, True, sock) @@ -506,7 +523,7 @@ class _SelectorTransport(transports._FlowControlMixin, _sock = None def __init__(self, loop, sock, protocol, extra=None, server=None): - super().__init__(extra, loop) + super(_SelectorTransport, self).__init__(extra, loop) self._extra['socket'] = sock self._extra['sockname'] = sock.getsockname() if 'peername' not in self._extra: @@ -593,7 +610,7 @@ class _SelectorTransport(transports._FlowControlMixin, if self._conn_lost: return if self._buffer: - self._buffer.clear() + del self._buffer[:] self._loop.remove_writer(self._sock_fd) if not self._closing: self._closing = True @@ -623,7 +640,7 @@ class _SelectorSocketTransport(_SelectorTransport): def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): - super().__init__(loop, sock, protocol, extra, server) + super(_SelectorSocketTransport, self).__init__(loop, sock, protocol, extra, server) self._eof = False self._paused = False @@ -657,7 +674,7 @@ class _SelectorSocketTransport(_SelectorTransport): def _read_ready(self): try: - data = self._sock.recv(self.max_size) + data = wrap_error(self._sock.recv, self.max_size) except (BlockingIOError, InterruptedError): pass except Exception as exc: @@ -678,9 +695,7 @@ class _SelectorSocketTransport(_SelectorTransport): self.close() def write(self, data): - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) + data = flatten_bytes(data) if self._eof: raise RuntimeError('Cannot call write() after write_eof()') if not data: @@ -695,7 +710,7 @@ class _SelectorSocketTransport(_SelectorTransport): if not self._buffer: # Optimization: try to send now. try: - n = self._sock.send(data) + n = wrap_error(self._sock.send, data) except (BlockingIOError, InterruptedError): pass except Exception as exc: @@ -715,13 +730,14 @@ class _SelectorSocketTransport(_SelectorTransport): def _write_ready(self): assert self._buffer, 'Data should not be empty' + data = flatten_bytes(self._buffer) try: - n = self._sock.send(self._buffer) + n = wrap_error(self._sock.send, data) except (BlockingIOError, InterruptedError): pass except Exception as exc: self._loop.remove_writer(self._sock_fd) - self._buffer.clear() + del self._buffer[:] self._fatal_error(exc, 'Fatal write error on socket transport') else: if n: @@ -766,7 +782,7 @@ class _SelectorSslTransport(_SelectorTransport): wrap_kwargs['server_hostname'] = server_hostname sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) - super().__init__(loop, sslsock, protocol, extra, server) + super(_SelectorSslTransport, self).__init__(loop, sslsock, protocol, extra, server) # the protocol connection is only made after the SSL handshake self._protocol_connected = False @@ -797,12 +813,12 @@ class _SelectorSslTransport(_SelectorTransport): def _on_handshake(self, start_time): try: - self._sock.do_handshake() - except ssl.SSLWantReadError: + wrap_ssl_error(self._sock.do_handshake) + except SSLWantReadError: self._loop.add_reader(self._sock_fd, self._on_handshake, start_time) return - except ssl.SSLWantWriteError: + except SSLWantWriteError: self._loop.add_writer(self._sock_fd, self._on_handshake, start_time) return @@ -842,8 +858,9 @@ class _SelectorSslTransport(_SelectorTransport): # Add extra info that becomes available after handshake. self._extra.update(peercert=peercert, cipher=self._sock.cipher(), - compression=self._sock.compression(), ) + if hasattr(self._sock, 'compression'): + self._extra['compression'] = self._sock.compression() self._read_wants_write = False self._write_wants_read = False @@ -883,6 +900,9 @@ class _SelectorSslTransport(_SelectorTransport): if self._loop.get_debug(): logger.debug("%r resumes reading", self) + def _sock_recv(self): + return wrap_ssl_error(self._sock.recv, self.max_size) + def _read_ready(self): if self._write_wants_read: self._write_wants_read = False @@ -892,10 +912,16 @@ class _SelectorSslTransport(_SelectorTransport): self._loop.add_writer(self._sock_fd, self._write_ready) try: - data = self._sock.recv(self.max_size) - except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): + if _SSL_REQUIRES_SELECT: + rfds = (self._sock.fileno(),) + rfds = select.select(rfds, (), (), 0.0)[0] + if not rfds: + # False alarm. + return + data = wrap_error(self._sock_recv) + except (BlockingIOError, InterruptedError, SSLWantReadError): pass - except ssl.SSLWantWriteError: + except SSLWantWriteError: self._read_wants_write = True self._loop.remove_reader(self._sock_fd) self._loop.add_writer(self._sock_fd, self._write_ready) @@ -924,17 +950,18 @@ class _SelectorSslTransport(_SelectorTransport): self._loop.add_reader(self._sock_fd, self._read_ready) if self._buffer: + data = flatten_bytes(self._buffer) try: - n = self._sock.send(self._buffer) - except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError): + n = wrap_error(self._sock.send, data) + except (BlockingIOError, InterruptedError, SSLWantWriteError): n = 0 - except ssl.SSLWantReadError: + except SSLWantReadError: n = 0 self._loop.remove_writer(self._sock_fd) self._write_wants_read = True except Exception as exc: self._loop.remove_writer(self._sock_fd) - self._buffer.clear() + del self._buffer[:] self._fatal_error(exc, 'Fatal write error on SSL transport') return @@ -949,9 +976,7 @@ class _SelectorSslTransport(_SelectorTransport): self._call_connection_lost(None) def write(self, data): - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) + data = flatten_bytes(data) if not data: return @@ -978,7 +1003,8 @@ class _SelectorDatagramTransport(_SelectorTransport): def __init__(self, loop, sock, protocol, address=None, waiter=None, extra=None): - super().__init__(loop, sock, protocol, extra) + super(_SelectorDatagramTransport, self).__init__(loop, sock, + protocol, extra) self._address = address self._loop.call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called @@ -993,7 +1019,7 @@ class _SelectorDatagramTransport(_SelectorTransport): def _read_ready(self): try: - data, addr = self._sock.recvfrom(self.max_size) + data, addr = wrap_error(self._sock.recvfrom, self.max_size) except (BlockingIOError, InterruptedError): pass except OSError as exc: @@ -1004,9 +1030,7 @@ class _SelectorDatagramTransport(_SelectorTransport): self._protocol.datagram_received(data, addr) def sendto(self, data, addr=None): - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) + data = flatten_bytes(data) if not data: return @@ -1024,9 +1048,9 @@ class _SelectorDatagramTransport(_SelectorTransport): # Attempt to send it right away first. try: if self._address: - self._sock.send(data) + wrap_error(self._sock.send, data) else: - self._sock.sendto(data, addr) + wrap_error(self._sock.sendto, data, addr) return except (BlockingIOError, InterruptedError): self._loop.add_writer(self._sock_fd, self._sendto_ready) @@ -1047,9 +1071,9 @@ class _SelectorDatagramTransport(_SelectorTransport): data, addr = self._buffer.popleft() try: if self._address: - self._sock.send(data) + wrap_error(self._sock.send, data) else: - self._sock.sendto(data, addr) + wrap_error(self._sock.sendto, data, addr) except (BlockingIOError, InterruptedError): self._buffer.appendleft((data, addr)) # Try again later. break diff --git a/trollius/selectors.py b/trollius/selectors.py index 6d569c3..d2f822c 100644 --- a/trollius/selectors.py +++ b/trollius/selectors.py @@ -11,6 +11,9 @@ import math import select import sys +from .py33_exceptions import wrap_error, InterruptedError +from .compat import integer_types + # generic events, that must be mapped to implementation-specific ones EVENT_READ = (1 << 0) @@ -29,16 +32,16 @@ def _fileobj_to_fd(fileobj): Raises: ValueError if the object is invalid """ - if isinstance(fileobj, int): + if isinstance(fileobj, integer_types): fd = fileobj else: try: fd = int(fileobj.fileno()) except (AttributeError, TypeError, ValueError): raise ValueError("Invalid file object: " - "{!r}".format(fileobj)) from None + "{0!r}".format(fileobj)) if fd < 0: - raise ValueError("Invalid file descriptor: {}".format(fd)) + raise ValueError("Invalid file descriptor: {0}".format(fd)) return fd @@ -61,13 +64,13 @@ class _SelectorMapping(Mapping): fd = self._selector._fileobj_lookup(fileobj) return self._selector._fd_to_key[fd] except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None + raise KeyError("{0!r} is not registered".format(fileobj)) def __iter__(self): return iter(self._selector._fd_to_key) -class BaseSelector(metaclass=ABCMeta): +class BaseSelector(object): """Selector abstract base class. A selector supports registering file objects to be monitored for specific @@ -81,6 +84,7 @@ class BaseSelector(metaclass=ABCMeta): depending on the platform. The default `Selector` class uses the most efficient implementation on the current platform. """ + __metaclass__ = ABCMeta @abstractmethod def register(self, fileobj, events, data=None): @@ -179,7 +183,7 @@ class BaseSelector(metaclass=ABCMeta): try: return mapping[fileobj] except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None + raise KeyError("{0!r} is not registered".format(fileobj)) @abstractmethod def get_map(self): @@ -223,12 +227,12 @@ class _BaseSelectorImpl(BaseSelector): def register(self, fileobj, events, data=None): if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): - raise ValueError("Invalid events: {!r}".format(events)) + raise ValueError("Invalid events: {0!r}".format(events)) key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) if key.fd in self._fd_to_key: - raise KeyError("{!r} (FD {}) is already registered" + raise KeyError("{0!r} (FD {1}) is already registered" .format(fileobj, key.fd)) self._fd_to_key[key.fd] = key @@ -238,7 +242,7 @@ class _BaseSelectorImpl(BaseSelector): try: key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None + raise KeyError("{0!r} is not registered".format(fileobj)) return key def modify(self, fileobj, events, data=None): @@ -246,7 +250,7 @@ class _BaseSelectorImpl(BaseSelector): try: key = self._fd_to_key[self._fileobj_lookup(fileobj)] except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None + raise KeyError("{0!r} is not registered".format(fileobj)) if events != key.events: self.unregister(fileobj) key = self.register(fileobj, events, data) @@ -282,12 +286,12 @@ class SelectSelector(_BaseSelectorImpl): """Select-based selector.""" def __init__(self): - super().__init__() + super(SelectSelector, self).__init__() self._readers = set() self._writers = set() def register(self, fileobj, events, data=None): - key = super().register(fileobj, events, data) + key = super(SelectSelector, self).register(fileobj, events, data) if events & EVENT_READ: self._readers.add(key.fd) if events & EVENT_WRITE: @@ -295,7 +299,7 @@ class SelectSelector(_BaseSelectorImpl): return key def unregister(self, fileobj): - key = super().unregister(fileobj) + key = super(SelectSelector, self).unregister(fileobj) self._readers.discard(key.fd) self._writers.discard(key.fd) return key @@ -311,7 +315,8 @@ class SelectSelector(_BaseSelectorImpl): timeout = None if timeout is None else max(timeout, 0) ready = [] try: - r, w, _ = self._select(self._readers, self._writers, [], timeout) + r, w, _ = wrap_error(self._select, + self._readers, self._writers, [], timeout) except InterruptedError: return ready r = set(r) @@ -335,11 +340,11 @@ if hasattr(select, 'poll'): """Poll-based selector.""" def __init__(self): - super().__init__() + super(PollSelector, self).__init__() self._poll = select.poll() def register(self, fileobj, events, data=None): - key = super().register(fileobj, events, data) + key = super(PollSelector, self).register(fileobj, events, data) poll_events = 0 if events & EVENT_READ: poll_events |= select.POLLIN @@ -349,7 +354,7 @@ if hasattr(select, 'poll'): return key def unregister(self, fileobj): - key = super().unregister(fileobj) + key = super(PollSelector, self).unregister(fileobj) self._poll.unregister(key.fd) return key @@ -361,10 +366,10 @@ if hasattr(select, 'poll'): else: # poll() has a resolution of 1 millisecond, round away from # zero to wait *at least* timeout seconds. - timeout = math.ceil(timeout * 1e3) + timeout = int(math.ceil(timeout * 1e3)) ready = [] try: - fd_event_list = self._poll.poll(timeout) + fd_event_list = wrap_error(self._poll.poll, timeout) except InterruptedError: return ready for fd, event in fd_event_list: @@ -386,14 +391,14 @@ if hasattr(select, 'epoll'): """Epoll-based selector.""" def __init__(self): - super().__init__() + super(EpollSelector, self).__init__() self._epoll = select.epoll() def fileno(self): return self._epoll.fileno() def register(self, fileobj, events, data=None): - key = super().register(fileobj, events, data) + key = super(EpollSelector, self).register(fileobj, events, data) epoll_events = 0 if events & EVENT_READ: epoll_events |= select.EPOLLIN @@ -403,7 +408,7 @@ if hasattr(select, 'epoll'): return key def unregister(self, fileobj): - key = super().unregister(fileobj) + key = super(EpollSelector, self).unregister(fileobj) try: self._epoll.unregister(key.fd) except OSError: @@ -429,7 +434,7 @@ if hasattr(select, 'epoll'): ready = [] try: - fd_event_list = self._epoll.poll(timeout, max_ev) + fd_event_list = wrap_error(self._epoll.poll, timeout, max_ev) except InterruptedError: return ready for fd, event in fd_event_list: @@ -446,7 +451,7 @@ if hasattr(select, 'epoll'): def close(self): self._epoll.close() - super().close() + super(EpollSelector, self).close() if hasattr(select, 'devpoll'): @@ -455,14 +460,14 @@ if hasattr(select, 'devpoll'): """Solaris /dev/poll selector.""" def __init__(self): - super().__init__() + super(DevpollSelector, self).__init__() self._devpoll = select.devpoll() def fileno(self): return self._devpoll.fileno() def register(self, fileobj, events, data=None): - key = super().register(fileobj, events, data) + key = super(DevpollSelector, self).register(fileobj, events, data) poll_events = 0 if events & EVENT_READ: poll_events |= select.POLLIN @@ -472,7 +477,7 @@ if hasattr(select, 'devpoll'): return key def unregister(self, fileobj): - key = super().unregister(fileobj) + key = super(DevpollSelector, self).unregister(fileobj) self._devpoll.unregister(key.fd) return key @@ -504,7 +509,7 @@ if hasattr(select, 'devpoll'): def close(self): self._devpoll.close() - super().close() + super(DevpollSelector, self).close() if hasattr(select, 'kqueue'): @@ -513,14 +518,14 @@ if hasattr(select, 'kqueue'): """Kqueue-based selector.""" def __init__(self): - super().__init__() + super(KqueueSelector, self).__init__() self._kqueue = select.kqueue() def fileno(self): return self._kqueue.fileno() def register(self, fileobj, events, data=None): - key = super().register(fileobj, events, data) + key = super(KqueueSelector, self).register(fileobj, events, data) if events & EVENT_READ: kev = select.kevent(key.fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) @@ -532,7 +537,7 @@ if hasattr(select, 'kqueue'): return key def unregister(self, fileobj): - key = super().unregister(fileobj) + key = super(KqueueSelector, self).unregister(fileobj) if key.events & EVENT_READ: kev = select.kevent(key.fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) @@ -557,7 +562,8 @@ if hasattr(select, 'kqueue'): max_ev = len(self._fd_to_key) ready = [] try: - kev_list = self._kqueue.control(None, max_ev, timeout) + kev_list = wrap_error(self._kqueue.control, + None, max_ev, timeout) except InterruptedError: return ready for kev in kev_list: @@ -576,7 +582,7 @@ if hasattr(select, 'kqueue'): def close(self): self._kqueue.close() - super().close() + super(KqueueSelector, self).close() # Choose the best implementation, roughly: diff --git a/trollius/sslproto.py b/trollius/sslproto.py index 235855e..5f4920a 100644 --- a/trollius/sslproto.py +++ b/trollius/sslproto.py @@ -9,6 +9,7 @@ except ImportError: # pragma: no cover from . import protocols from . import transports from .log import logger +from .py3_ssl import BACKPORT_SSL_CONTEXT def _create_transport_context(server_side, server_hostname): @@ -26,10 +27,11 @@ def _create_transport_context(server_side, server_hostname): else: # Fallback for Python 3.3. sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.set_default_verify_paths() - sslcontext.verify_mode = ssl.CERT_REQUIRED + if not BACKPORT_SSL_CONTEXT: + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED return sslcontext @@ -43,6 +45,11 @@ _DO_HANDSHAKE = "DO_HANDSHAKE" _WRAPPED = "WRAPPED" _SHUTDOWN = "SHUTDOWN" +if hasattr(ssl, 'CertificateError'): + _SSL_ERRORS = (ssl.SSLError, ssl.CertificateError) +else: + _SSL_ERRORS = ssl.SSLError + class _SSLPipe(object): """An SSL "Pipe". @@ -224,7 +231,7 @@ class _SSLPipe(object): elif self._state == _UNWRAPPED: # Drain possible plaintext data after close_notify. appdata.append(self._incoming.read()) - except (ssl.SSLError, ssl.CertificateError) as exc: + except _SSL_ERRORS as exc: if getattr(exc, 'errno', None) not in ( ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, ssl.SSL_ERROR_SYSCALL): @@ -569,7 +576,8 @@ class SSLProtocol(protocols.Protocol): ssl.match_hostname(peercert, self._server_hostname) except BaseException as exc: if self._loop.get_debug(): - if isinstance(exc, ssl.CertificateError): + if (hasattr(ssl, 'CertificateError') + and isinstance(exc, ssl.CertificateError)): logger.warning("%r: SSL handshake failed " "on verifying the certificate", self, exc_info=True) diff --git a/trollius/streams.py b/trollius/streams.py index 176c65e..c235c5a 100644 --- a/trollius/streams.py +++ b/trollius/streams.py @@ -15,7 +15,8 @@ from . import coroutines from . import events from . import futures from . import protocols -from .coroutines import coroutine +from .coroutines import coroutine, From, Return +from .py33_exceptions import ConnectionResetError from .log import logger @@ -38,7 +39,7 @@ class IncompleteReadError(EOFError): @coroutine -def open_connection(host=None, port=None, *, +def open_connection(host=None, port=None, loop=None, limit=_DEFAULT_LIMIT, **kwds): """A wrapper for create_connection() returning a (reader, writer) pair. @@ -61,14 +62,14 @@ def open_connection(host=None, port=None, *, loop = events.get_event_loop() reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, loop=loop) - transport, _ = yield from loop.create_connection( - lambda: protocol, host, port, **kwds) + transport, _ = yield From(loop.create_connection( + lambda: protocol, host, port, **kwds)) writer = StreamWriter(transport, protocol, reader, loop) - return reader, writer + raise Return(reader, writer) @coroutine -def start_server(client_connected_cb, host=None, port=None, *, +def start_server(client_connected_cb, host=None, port=None, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Start a socket server, call back for each client connected. @@ -100,28 +101,29 @@ def start_server(client_connected_cb, host=None, port=None, *, loop=loop) return protocol - return (yield from loop.create_server(factory, host, port, **kwds)) + server = yield From(loop.create_server(factory, host, port, **kwds)) + raise Return(server) if hasattr(socket, 'AF_UNIX'): # UNIX Domain Sockets are supported on this platform @coroutine - def open_unix_connection(path=None, *, + def open_unix_connection(path=None, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `open_connection` but works with UNIX Domain Sockets.""" if loop is None: loop = events.get_event_loop() reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, loop=loop) - transport, _ = yield from loop.create_unix_connection( - lambda: protocol, path, **kwds) + transport, _ = yield From(loop.create_unix_connection( + lambda: protocol, path, **kwds)) writer = StreamWriter(transport, protocol, reader, loop) - return reader, writer + raise Return(reader, writer) @coroutine - def start_unix_server(client_connected_cb, path=None, *, + def start_unix_server(client_connected_cb, path=None, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `start_server` but works with UNIX Domain Sockets.""" if loop is None: @@ -133,7 +135,8 @@ if hasattr(socket, 'AF_UNIX'): loop=loop) return protocol - return (yield from loop.create_unix_server(factory, path, **kwds)) + server = (yield From(loop.create_unix_server(factory, path, **kwds))) + raise Return(server) class FlowControlMixin(protocols.Protocol): @@ -199,7 +202,7 @@ class FlowControlMixin(protocols.Protocol): assert waiter is None or waiter.cancelled() waiter = futures.Future(loop=self._loop) self._drain_waiter = waiter - yield from waiter + yield From(waiter) class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): @@ -212,7 +215,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): """ def __init__(self, stream_reader, client_connected_cb=None, loop=None): - super().__init__(loop=loop) + super(StreamReaderProtocol, self).__init__(loop=loop) self._stream_reader = stream_reader self._stream_writer = None self._client_connected_cb = client_connected_cb @@ -233,7 +236,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): self._stream_reader.feed_eof() else: self._stream_reader.set_exception(exc) - super().connection_lost(exc) + super(StreamReaderProtocol, self).connection_lost(exc) def data_received(self, data): self._stream_reader.feed_data(data) @@ -242,7 +245,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): self._stream_reader.feed_eof() -class StreamWriter: +class StreamWriter(object): """Wraps a Transport. This exposes write(), writelines(), [can_]write_eof(), @@ -295,16 +298,16 @@ class StreamWriter: The intended use is to write w.write(data) - yield from w.drain() + yield From(w.drain()) """ if self._reader is not None: exc = self._reader.exception() if exc is not None: raise exc - yield from self._protocol._drain_helper() + yield From(self._protocol._drain_helper()) -class StreamReader: +class StreamReader(object): def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # The line length limit is a security feature; @@ -391,9 +394,16 @@ class StreamReader: raise RuntimeError('%s() called while another coroutine is ' 'already waiting for incoming data' % func_name) + # In asyncio, there is no need to recheck if we got data or EOF thanks + # to "yield from". In trollius, a StreamReader method can be called + # after the _wait_for_data() coroutine is scheduled and before it is + # really executed. + if self._buffer or self._eof: + return + self._waiter = futures.Future(loop=self._loop) try: - yield from self._waiter + yield From(self._waiter) finally: self._waiter = None @@ -410,7 +420,7 @@ class StreamReader: ichar = self._buffer.find(b'\n') if ichar < 0: line.extend(self._buffer) - self._buffer.clear() + del self._buffer[:] else: ichar += 1 line.extend(self._buffer[:ichar]) @@ -425,10 +435,10 @@ class StreamReader: break if not_enough: - yield from self._wait_for_data('readline') + yield From(self._wait_for_data('readline')) self._maybe_resume_transport() - return bytes(line) + raise Return(bytes(line)) @coroutine def read(self, n=-1): @@ -436,7 +446,7 @@ class StreamReader: raise self._exception if not n: - return b'' + raise Return(b'') if n < 0: # This used to just loop creating a new waiter hoping to @@ -445,25 +455,25 @@ class StreamReader: # bytes. So just call self.read(self._limit) until EOF. blocks = [] while True: - block = yield from self.read(self._limit) + block = yield From(self.read(self._limit)) if not block: break blocks.append(block) - return b''.join(blocks) + raise Return(b''.join(blocks)) else: if not self._buffer and not self._eof: - yield from self._wait_for_data('read') + yield From(self._wait_for_data('read')) if n < 0 or len(self._buffer) <= n: data = bytes(self._buffer) - self._buffer.clear() + del self._buffer[:] else: # n > 0 and len(self._buffer) > n data = bytes(self._buffer[:n]) del self._buffer[:n] self._maybe_resume_transport() - return data + raise Return(data) @coroutine def readexactly(self, n): @@ -479,14 +489,14 @@ class StreamReader: blocks = [] while n > 0: - block = yield from self.read(n) + block = yield From(self.read(n)) if not block: partial = b''.join(blocks) raise IncompleteReadError(partial, len(partial) + n) blocks.append(block) n -= len(block) - return b''.join(blocks) + raise Return(b''.join(blocks)) if _PY35: @coroutine diff --git a/trollius/subprocess.py b/trollius/subprocess.py index 4600a9f..2f1becf 100644 --- a/trollius/subprocess.py +++ b/trollius/subprocess.py @@ -8,13 +8,16 @@ from . import futures from . import protocols from . import streams from . import tasks -from .coroutines import coroutine +from .coroutines import coroutine, From, Return +from .py33_exceptions import (BrokenPipeError, ConnectionResetError, + ProcessLookupError) from .log import logger PIPE = subprocess.PIPE STDOUT = subprocess.STDOUT -DEVNULL = subprocess.DEVNULL +if hasattr(subprocess, 'DEVNULL'): + DEVNULL = subprocess.DEVNULL class SubprocessStreamProtocol(streams.FlowControlMixin, @@ -22,7 +25,7 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, """Like StreamReaderProtocol, but for a subprocess.""" def __init__(self, limit, loop): - super().__init__(loop=loop) + super(SubprocessStreamProtocol, self).__init__(loop=loop) self._limit = limit self.stdin = self.stdout = self.stderr = None self._transport = None @@ -115,7 +118,8 @@ class Process: """Wait until the process exit and return the process return code. This method is a coroutine.""" - return (yield from self._transport._wait()) + return_code = yield From(self._transport._wait()) + raise Return(return_code) def send_signal(self, signal): self._transport.send_signal(signal) @@ -134,7 +138,7 @@ class Process: logger.debug('%r communicate: feed stdin (%s bytes)', self, len(input)) try: - yield from self.stdin.drain() + yield From(self.stdin.drain()) except (BrokenPipeError, ConnectionResetError) as exc: # communicate() ignores BrokenPipeError and ConnectionResetError if debug: @@ -159,12 +163,12 @@ class Process: if self._loop.get_debug(): name = 'stdout' if fd == 1 else 'stderr' logger.debug('%r communicate: read %s', self, name) - output = yield from stream.read() + output = yield From(stream.read()) if self._loop.get_debug(): name = 'stdout' if fd == 1 else 'stderr' logger.debug('%r communicate: close %s', self, name) transport.close() - return output + raise Return(output) @coroutine def communicate(self, input=None): @@ -180,36 +184,43 @@ class Process: stderr = self._read_stream(2) else: stderr = self._noop() - stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr, - loop=self._loop) - yield from self.wait() - return (stdout, stderr) + stdin, stdout, stderr = yield From(tasks.gather(stdin, stdout, stderr, + loop=self._loop)) + yield From(self.wait()) + raise Return(stdout, stderr) @coroutine -def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, - loop=None, limit=streams._DEFAULT_LIMIT, **kwds): +def create_subprocess_shell(cmd, **kwds): + stdin = kwds.pop('stdin', None) + stdout = kwds.pop('stdout', None) + stderr = kwds.pop('stderr', None) + loop = kwds.pop('loop', None) + limit = kwds.pop('limit', streams._DEFAULT_LIMIT) if loop is None: loop = events.get_event_loop() protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, loop=loop) - transport, protocol = yield from loop.subprocess_shell( + transport, protocol = yield From(loop.subprocess_shell( protocol_factory, cmd, stdin=stdin, stdout=stdout, - stderr=stderr, **kwds) - return Process(transport, protocol, loop) + stderr=stderr, **kwds)) + raise Return(Process(transport, protocol, loop)) @coroutine -def create_subprocess_exec(program, *args, stdin=None, stdout=None, - stderr=None, loop=None, - limit=streams._DEFAULT_LIMIT, **kwds): +def create_subprocess_exec(program, *args, **kwds): + stdin = kwds.pop('stdin', None) + stdout = kwds.pop('stdout', None) + stderr = kwds.pop('stderr', None) + loop = kwds.pop('loop', None) + limit = kwds.pop('limit', streams._DEFAULT_LIMIT) if loop is None: loop = events.get_event_loop() protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, loop=loop) - transport, protocol = yield from loop.subprocess_exec( + transport, protocol = yield From(loop.subprocess_exec( protocol_factory, program, *args, stdin=stdin, stdout=stdout, - stderr=stderr, **kwds) - return Process(transport, protocol, loop) + stderr=stderr, **kwds)) + raise Return(Process(transport, protocol, loop)) diff --git a/trollius/tasks.py b/trollius/tasks.py index d8193ba..1fb23cd 100644 --- a/trollius/tasks.py +++ b/trollius/tasks.py @@ -14,16 +14,30 @@ import sys import types import traceback import warnings -import weakref +try: + from weakref import WeakSet +except ImportError: + # Python 2.6 + from .py27_weakrefset import WeakSet +from . import compat from . import coroutines from . import events +from . import executor from . import futures -from .coroutines import coroutine +from .locks import Lock, Condition, Semaphore, _ContextManager +from .coroutines import coroutine, From, Return + _PY34 = (sys.version_info >= (3, 4)) +@coroutine +def _lock_coroutine(lock): + yield From(lock.acquire()) + raise Return(_ContextManager(lock)) + + class Task(futures.Future): """A coroutine wrapped in a Future.""" @@ -37,7 +51,7 @@ class Task(futures.Future): # must be _wakeup(). # Weak set containing all tasks alive. - _all_tasks = weakref.WeakSet() + _all_tasks = WeakSet() # Dictionary containing tasks that are currently active in # all running event loops. {EventLoop: Task} @@ -67,11 +81,11 @@ class Task(futures.Future): """ if loop is None: loop = events.get_event_loop() - return {t for t in cls._all_tasks if t._loop is loop} + return set(t for t in cls._all_tasks if t._loop is loop) - def __init__(self, coro, *, loop=None): + def __init__(self, coro, loop=None): assert coroutines.iscoroutine(coro), repr(coro) - super().__init__(loop=loop) + super(Task, self).__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] self._coro = coro @@ -96,7 +110,7 @@ class Task(futures.Future): futures.Future.__del__(self) def _repr_info(self): - info = super()._repr_info() + info = super(Task, self)._repr_info() if self._must_cancel: # replace status @@ -109,7 +123,7 @@ class Task(futures.Future): info.insert(2, 'wait_for=%r' % self._fut_waiter) return info - def get_stack(self, *, limit=None): + def get_stack(self, limit=None): """Return the list of stack frames for this task's coroutine. If the coroutine is not done, this returns the stack where it is @@ -152,7 +166,7 @@ class Task(futures.Future): tb = tb.tb_next return frames - def print_stack(self, *, limit=None, file=None): + def print_stack(self, limit=None, file=None): """Print the stack or traceback for this task's coroutine. This produces output similar to that of the traceback module, @@ -219,9 +233,9 @@ class Task(futures.Future): self._must_cancel = True return True - def _step(self, value=None, exc=None): + def _step(self, value=None, exc=None, exc_tb=None): assert not self.done(), \ - '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + '_step(): already done: {0!r}, {1!r}, {2!r}'.format(self, value, exc) if self._must_cancel: if not isinstance(exc, futures.CancelledError): exc = futures.CancelledError() @@ -229,6 +243,10 @@ class Task(futures.Future): coro = self._coro self._fut_waiter = None + if exc_tb is not None: + init_exc = exc + else: + init_exc = None self.__class__._current_tasks[self._loop] = self # Call either coro.throw(exc) or coro.send(value). try: @@ -237,71 +255,104 @@ class Task(futures.Future): else: result = coro.send(value) except StopIteration as exc: - self.set_result(exc.value) - except futures.CancelledError as exc: - super().cancel() # I.e., Future.cancel(self). - except Exception as exc: - self.set_exception(exc) - except BaseException as exc: - self.set_exception(exc) - raise - else: - if isinstance(result, futures.Future): - # Yielded Future must come from Future.__iter__(). - if result._blocking: - result._blocking = False - result.add_done_callback(self._wakeup) - self._fut_waiter = result - if self._must_cancel: - if self._fut_waiter.cancel(): - self._must_cancel = False + if compat.PY33: + # asyncio Task object? get the result of the coroutine + result = exc.value + else: + if isinstance(exc, Return): + exc.raised = True + result = exc.value else: - self._loop.call_soon( - self._step, None, - RuntimeError( - 'yield was used instead of yield from ' - 'in task {!r} with {!r}'.format(self, result))) + result = None + self.set_result(result) + except futures.CancelledError as exc: + super(Task, self).cancel() # I.e., Future.cancel(self). + except BaseException as exc: + if exc is init_exc: + self._set_exception_with_tb(exc, exc_tb) + exc_tb = None + else: + self.set_exception(exc) + + if not isinstance(exc, Exception): + # reraise BaseException + raise + else: + if coroutines._DEBUG: + if not coroutines._coroutine_at_yield_from(self._coro): + # trollius coroutine must "yield From(...)" + if not isinstance(result, coroutines.FromWrapper): + self._loop.call_soon( + self._step, None, + RuntimeError("yield used without From")) + return + result = result.obj + else: + # asyncio coroutine using "yield from ..." + if isinstance(result, coroutines.FromWrapper): + result = result.obj + elif isinstance(result, coroutines.FromWrapper): + result = result.obj + + if coroutines.iscoroutine(result): + # "yield coroutine" creates a task, the current task + # will wait until the new task is done + result = self._loop.create_task(result) + # FIXME: faster check. common base class? hasattr? + elif isinstance(result, (Lock, Condition, Semaphore)): + coro = _lock_coroutine(result) + result = self._loop.create_task(coro) + + if isinstance(result, futures._FUTURE_CLASSES): + # Yielded Future must come from Future.__iter__(). + result.add_done_callback(self._wakeup) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False elif result is None: # Bare yield relinquishes control for one event loop iteration. self._loop.call_soon(self._step) - elif inspect.isgenerator(result): - # Yielding a generator is just wrong. - self._loop.call_soon( - self._step, None, - RuntimeError( - 'yield was used instead of yield from for ' - 'generator in task {!r} with {}'.format( - self, result))) else: # Yielding something else is an error. self._loop.call_soon( self._step, None, RuntimeError( - 'Task got bad yield: {!r}'.format(result))) + 'Task got bad yield: {0!r}'.format(result))) finally: self.__class__._current_tasks.pop(self._loop) self = None # Needed to break cycles when an exception occurs. def _wakeup(self, future): - try: - value = future.result() - except Exception as exc: - # This may also be a cancellation. - self._step(None, exc) + if (future._state == futures._FINISHED + and future._exception is not None): + # Get the traceback before calling exception(), because calling + # the exception() method clears the traceback + exc_tb = future._get_exception_tb() + exc = future.exception() + self._step(None, exc, exc_tb) + exc_tb = None else: - self._step(value, None) + try: + value = future.result() + except Exception as exc: + # This may also be a cancellation. + self._step(None, exc) + else: + self._step(value, None) self = None # Needed to break cycles when an exception occurs. # wait() and as_completed() similar to those in PEP 3148. -FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED -FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION -ALL_COMPLETED = concurrent.futures.ALL_COMPLETED +# Export symbols in trollius.tasks for compatibility with asyncio +FIRST_COMPLETED = executor.FIRST_COMPLETED +FIRST_EXCEPTION = executor.FIRST_EXCEPTION +ALL_COMPLETED = executor.ALL_COMPLETED @coroutine -def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): +def wait(fs, loop=None, timeout=None, return_when=ALL_COMPLETED): """Wait for the Futures and coroutines given by fs to complete. The sequence futures must not be empty. @@ -312,24 +363,25 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): Usage: - done, pending = yield from asyncio.wait(fs) + done, pending = yield From(asyncio.wait(fs)) Note: This does not raise TimeoutError! Futures that aren't done when the timeout occurs are returned in the second set. """ - if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): + if isinstance(fs, futures._FUTURE_CLASSES) or coroutines.iscoroutine(fs): raise TypeError("expect a list of futures, not %s" % type(fs).__name__) if not fs: raise ValueError('Set of coroutines/Futures is empty.') if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): - raise ValueError('Invalid return_when value: {}'.format(return_when)) + raise ValueError('Invalid return_when value: {0}'.format(return_when)) if loop is None: loop = events.get_event_loop() - fs = {ensure_future(f, loop=loop) for f in set(fs)} + fs = set(ensure_future(f, loop=loop) for f in set(fs)) - return (yield from _wait(fs, timeout, return_when, loop)) + result = yield From(_wait(fs, timeout, return_when, loop)) + raise Return(result) def _release_waiter(waiter, *args): @@ -338,7 +390,7 @@ def _release_waiter(waiter, *args): @coroutine -def wait_for(fut, timeout, *, loop=None): +def wait_for(fut, timeout, loop=None): """Wait for the single Future or coroutine to complete, with timeout. Coroutine will be wrapped in Task. @@ -355,7 +407,8 @@ def wait_for(fut, timeout, *, loop=None): loop = events.get_event_loop() if timeout is None: - return (yield from fut) + result = yield From(fut) + raise Return(result) waiter = futures.Future(loop=loop) timeout_handle = loop.call_later(timeout, _release_waiter, waiter) @@ -367,14 +420,14 @@ def wait_for(fut, timeout, *, loop=None): try: # wait until the future completes or the timeout try: - yield from waiter + yield From(waiter) except futures.CancelledError: fut.remove_done_callback(cb) fut.cancel() raise if fut.done(): - return fut.result() + raise Return(fut.result()) else: fut.remove_done_callback(cb) fut.cancel() @@ -394,12 +447,11 @@ def _wait(fs, timeout, return_when, loop): timeout_handle = None if timeout is not None: timeout_handle = loop.call_later(timeout, _release_waiter, waiter) - counter = len(fs) + non_local = {'counter': len(fs)} def _on_completion(f): - nonlocal counter - counter -= 1 - if (counter <= 0 or + non_local['counter'] -= 1 + if (non_local['counter'] <= 0 or return_when == FIRST_COMPLETED or return_when == FIRST_EXCEPTION and (not f.cancelled() and f.exception() is not None)): @@ -412,7 +464,7 @@ def _wait(fs, timeout, return_when, loop): f.add_done_callback(_on_completion) try: - yield from waiter + yield From(waiter) finally: if timeout_handle is not None: timeout_handle.cancel() @@ -424,11 +476,11 @@ def _wait(fs, timeout, return_when, loop): done.add(f) else: pending.add(f) - return done, pending + raise Return(done, pending) # This is *not* a @coroutine! It is just an iterator (yielding Futures). -def as_completed(fs, *, loop=None, timeout=None): +def as_completed(fs, loop=None, timeout=None): """Return an iterator whose values are coroutines. When waiting for the yielded coroutines you'll get the results (or @@ -438,18 +490,18 @@ def as_completed(fs, *, loop=None, timeout=None): This differs from PEP 3148; the proper way to use this is: for f in as_completed(fs): - result = yield from f # The 'yield from' may raise. + result = yield From(f) # The 'yield' may raise. # Use result. - If a timeout is specified, the 'yield from' will raise + If a timeout is specified, the 'yield' will raise TimeoutError when the timeout occurs before all Futures are done. Note: The futures 'f' are not necessarily members of fs. """ - if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): + if isinstance(fs, futures._FUTURE_CLASSES) or coroutines.iscoroutine(fs): raise TypeError("expect a list of futures, not %s" % type(fs).__name__) loop = loop if loop is not None else events.get_event_loop() - todo = {ensure_future(f, loop=loop) for f in set(fs)} + todo = set(ensure_future(f, loop=loop) for f in set(fs)) from .queues import Queue # Import here to avoid circular import problem. done = Queue(loop=loop) timeout_handle = None @@ -470,11 +522,11 @@ def as_completed(fs, *, loop=None, timeout=None): @coroutine def _wait_for_one(): - f = yield from done.get() + f = yield From(done.get()) if f is None: # Dummy value from _on_timeout(). raise futures.TimeoutError - return f.result() # May raise f.exception(). + raise Return(f.result()) # May raise f.exception(). for f in todo: f.add_done_callback(_on_completion) @@ -485,18 +537,19 @@ def as_completed(fs, *, loop=None, timeout=None): @coroutine -def sleep(delay, result=None, *, loop=None): +def sleep(delay, result=None, loop=None): """Coroutine that completes after a given time (in seconds).""" future = futures.Future(loop=loop) h = future._loop.call_later(delay, future._set_result_unless_cancelled, result) try: - return (yield from future) + result = yield From(future) + raise Return(result) finally: h.cancel() -def async(coro_or_future, *, loop=None): +def async(coro_or_future, loop=None): """Wrap a coroutine in a future. If the argument is a Future, it is returned directly. @@ -515,7 +568,10 @@ def ensure_future(coro_or_future, *, loop=None): If the argument is a Future, it is returned directly. """ - if isinstance(coro_or_future, futures.Future): + # FIXME: only check if coroutines._DEBUG is True? + if isinstance(coro_or_future, coroutines.FromWrapper): + coro_or_future = coro_or_future.obj + if isinstance(coro_or_future, futures._FUTURE_CLASSES): if loop is not None and loop is not coro_or_future._loop: raise ValueError('loop argument must agree with Future') return coro_or_future @@ -538,8 +594,8 @@ class _GatheringFuture(futures.Future): cancelled. """ - def __init__(self, children, *, loop=None): - super().__init__(loop=loop) + def __init__(self, children, loop=None): + super(_GatheringFuture, self).__init__(loop=loop) self._children = children def cancel(self): @@ -550,7 +606,7 @@ class _GatheringFuture(futures.Future): return True -def gather(*coros_or_futures, loop=None, return_exceptions=False): +def gather(*coros_or_futures, **kw): """Return a future aggregating results from the given coroutines or futures. @@ -570,6 +626,11 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): prevent the cancellation of one child to cause other children to be cancelled.) """ + loop = kw.pop('loop', None) + return_exceptions = kw.pop('return_exceptions', False) + if kw: + raise TypeError("unexpected keyword") + if not coros_or_futures: outer = futures.Future(loop=loop) outer.set_result([]) @@ -577,7 +638,7 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): arg_to_fut = {} for arg in set(coros_or_futures): - if not isinstance(arg, futures.Future): + if not isinstance(arg, futures._FUTURE_CLASSES): fut = ensure_future(arg, loop=loop) if loop is None: loop = fut._loop @@ -595,11 +656,10 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): children = [arg_to_fut[arg] for arg in coros_or_futures] nchildren = len(children) outer = _GatheringFuture(children, loop=loop) - nfinished = 0 + non_local = {'nfinished': 0} results = [None] * nchildren def _done_callback(i, fut): - nonlocal nfinished if outer.done(): if not fut.cancelled(): # Mark exception retrieved. @@ -619,8 +679,8 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): else: res = fut._result results[i] = res - nfinished += 1 - if nfinished == nchildren: + non_local['nfinished'] += 1 + if non_local['nfinished'] == nchildren: outer.set_result(results) for i, fut in enumerate(children): @@ -628,16 +688,16 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): return outer -def shield(arg, *, loop=None): +def shield(arg, loop=None): """Wait for a future, shielding it from cancellation. The statement - res = yield from shield(something()) + res = yield From(shield(something())) is exactly equivalent to the statement - res = yield from something() + res = yield From(something()) *except* that if the coroutine containing it is cancelled, the task running in something() is not cancelled. From the POV of @@ -650,7 +710,7 @@ def shield(arg, *, loop=None): you can combine shield() with a try/except clause, as follows: try: - res = yield from shield(something()) + res = yield From(shield(something())) except CancelledError: res = None """ diff --git a/trollius/test_support.py b/trollius/test_support.py index 0fadfad..b40576a 100644 --- a/trollius/test_support.py +++ b/trollius/test_support.py @@ -4,6 +4,7 @@ # Ignore symbol TEST_HOME_DIR: test_events works without it +from __future__ import absolute_import import functools import gc import os @@ -14,6 +15,7 @@ import subprocess import sys import time +from trollius import test_utils # A constant likely larger than the underlying OS pipe buffer size, to # make writes blocking. @@ -39,7 +41,9 @@ def _assert_python(expected_success, *args, **env_vars): isolated = env_vars.pop('__isolated') else: isolated = not env_vars - cmd_line = [sys.executable, '-X', 'faulthandler'] + cmd_line = [sys.executable] + if sys.version_info >= (3, 3): + cmd_line.extend(('-X', 'faulthandler')) if isolated and sys.version_info >= (3, 4): # isolated mode: ignore Python environment variables, ignore user # site-packages, and don't add the current directory to sys.path @@ -248,7 +252,7 @@ def requires_mac_ver(*min_version): else: if version < min_version: min_version_txt = '.'.join(map(str, min_version)) - raise unittest.SkipTest( + raise test_utils.SkipTest( "Mac OS X %s or higher required, not %s" % (min_version_txt, version_txt)) return func(*args, **kw) @@ -275,7 +279,7 @@ def _requires_unix_version(sysname, min_version): else: if version < min_version: min_version_txt = '.'.join(map(str, min_version)) - raise unittest.SkipTest( + raise test_utils.SkipTest( "%s version %s or higher required, not %s" % (sysname, min_version_txt, version_txt)) return func(*args, **kw) @@ -300,9 +304,6 @@ except ImportError: # Use test.script_helper if available try: - from test.support.script_helper import assert_python_ok + from test.script_helper import assert_python_ok except ImportError: - try: - from test.script_helper import assert_python_ok - except ImportError: - pass + pass diff --git a/trollius/test_utils.py b/trollius/test_utils.py index af7f5bc..caa98f7 100644 --- a/trollius/test_utils.py +++ b/trollius/test_utils.py @@ -7,20 +7,32 @@ import logging import os import re import socket -import socketserver import sys import tempfile import threading import time -import unittest -from unittest import mock -from http.server import HTTPServer from wsgiref.simple_server import WSGIRequestHandler, WSGIServer +try: + import socketserver + from http.server import HTTPServer +except ImportError: + # Python 2 + import SocketServer as socketserver + from BaseHTTPServer import HTTPServer + +try: + from unittest import mock +except ImportError: + # Python < 3.3 + import mock + try: import ssl + from .py3_ssl import SSLContext, wrap_socket except ImportError: # pragma: no cover + # SSL support disabled in Python ssl = None from . import base_events @@ -37,27 +49,116 @@ if sys.platform == 'win32': # pragma: no cover else: from socket import socketpair # pragma: no cover +try: + import unittest + skipIf = unittest.skipIf + skipUnless = unittest.skipUnless + SkipTest = unittest.SkipTest + _TestCase = unittest.TestCase +except AttributeError: + # Python 2.6: use the backported unittest module called "unittest2" + import unittest2 + skipIf = unittest2.skipIf + skipUnless = unittest2.skipUnless + SkipTest = unittest2.SkipTest + _TestCase = unittest2.TestCase + + +if not hasattr(_TestCase, 'assertRaisesRegex'): + class _BaseTestCaseContext: + + def __init__(self, test_case): + self.test_case = test_case + + def _raiseFailure(self, standardMsg): + msg = self.test_case._formatMessage(self.msg, standardMsg) + raise self.test_case.failureException(msg) + + + class _AssertRaisesBaseContext(_BaseTestCaseContext): + + def __init__(self, expected, test_case, callable_obj=None, + expected_regex=None): + _BaseTestCaseContext.__init__(self, test_case) + self.expected = expected + self.test_case = test_case + if callable_obj is not None: + try: + self.obj_name = callable_obj.__name__ + except AttributeError: + self.obj_name = str(callable_obj) + else: + self.obj_name = None + if isinstance(expected_regex, (bytes, str)): + expected_regex = re.compile(expected_regex) + self.expected_regex = expected_regex + self.msg = None + + def handle(self, name, callable_obj, args, kwargs): + """ + If callable_obj is None, assertRaises/Warns is being used as a + context manager, so check for a 'msg' kwarg and return self. + If callable_obj is not None, call it passing args and kwargs. + """ + if callable_obj is None: + self.msg = kwargs.pop('msg', None) + return self + with self: + callable_obj(*args, **kwargs) + + + class _AssertRaisesContext(_AssertRaisesBaseContext): + """A context manager used to implement TestCase.assertRaises* methods.""" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + if exc_type is None: + try: + exc_name = self.expected.__name__ + except AttributeError: + exc_name = str(self.expected) + if self.obj_name: + self._raiseFailure("{0} not raised by {1}".format(exc_name, + self.obj_name)) + else: + self._raiseFailure("{0} not raised".format(exc_name)) + if not issubclass(exc_type, self.expected): + # let unexpected exceptions pass through + return False + self.exception = exc_value + if self.expected_regex is None: + return True + + expected_regex = self.expected_regex + if not expected_regex.search(str(exc_value)): + self._raiseFailure('"{0}" does not match "{1}"'.format( + expected_regex.pattern, str(exc_value))) + return True + def dummy_ssl_context(): if ssl is None: return None else: - return ssl.SSLContext(ssl.PROTOCOL_SSLv23) + return SSLContext(ssl.PROTOCOL_SSLv23) -def run_briefly(loop): +def run_briefly(loop, steps=1): @coroutine def once(): pass - gen = once() - t = loop.create_task(gen) - # Don't log a warning if the task is not done after run_until_complete(). - # It occurs if the loop is stopped or if a task raises a BaseException. - t._log_destroy_pending = False - try: - loop.run_until_complete(t) - finally: - gen.close() + for step in range(steps): + gen = once() + t = loop.create_task(gen) + # Don't log a warning if the task is not done after run_until_complete(). + # It occurs if the loop is stopped or if a task raises a BaseException. + t._log_destroy_pending = False + try: + loop.run_until_complete(t) + finally: + gen.close() def run_until(loop, pred, timeout=30): @@ -89,12 +190,12 @@ class SilentWSGIRequestHandler(WSGIRequestHandler): pass -class SilentWSGIServer(WSGIServer): +class SilentWSGIServer(WSGIServer, object): request_timeout = 2 def get_request(self): - request, client_addr = super().get_request() + request, client_addr = super(SilentWSGIServer, self).get_request() request.settimeout(self.request_timeout) return request, client_addr @@ -115,10 +216,10 @@ class SSLWSGIServerMixin: 'test', 'test_asyncio') keyfile = os.path.join(here, 'ssl_key.pem') certfile = os.path.join(here, 'ssl_cert.pem') - ssock = ssl.wrap_socket(request, - keyfile=keyfile, - certfile=certfile, - server_side=True) + ssock = wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) try: self.RequestHandlerClass(ssock, client_address, self) ssock.close() @@ -131,7 +232,7 @@ class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): pass -def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): +def _run_test_server(address, use_ssl, server_cls, server_ssl_cls): def app(environ, start_response): status = '200 OK' @@ -158,7 +259,7 @@ def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): if hasattr(socket, 'AF_UNIX'): - class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): + class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer, object): def server_bind(self): socketserver.UnixStreamServer.server_bind(self) @@ -166,7 +267,7 @@ if hasattr(socket, 'AF_UNIX'): self.server_port = 80 - class UnixWSGIServer(UnixHTTPServer, WSGIServer): + class UnixWSGIServer(UnixHTTPServer, WSGIServer, object): request_timeout = 2 @@ -175,7 +276,7 @@ if hasattr(socket, 'AF_UNIX'): self.setup_environ() def get_request(self): - request, client_addr = super().get_request() + request, client_addr = super(UnixWSGIServer, self).get_request() request.settimeout(self.request_timeout) # Code in the stdlib expects that get_request # will return a socket and a tuple (host, port). @@ -214,18 +315,20 @@ if hasattr(socket, 'AF_UNIX'): @contextlib.contextmanager - def run_test_unix_server(*, use_ssl=False): + def run_test_unix_server(use_ssl=False): with unix_socket_path() as path: - yield from _run_test_server(address=path, use_ssl=use_ssl, - server_cls=SilentUnixWSGIServer, - server_ssl_cls=UnixSSLWSGIServer) + for item in _run_test_server(address=path, use_ssl=use_ssl, + server_cls=SilentUnixWSGIServer, + server_ssl_cls=UnixSSLWSGIServer): + yield item @contextlib.contextmanager -def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): - yield from _run_test_server(address=(host, port), use_ssl=use_ssl, - server_cls=SilentWSGIServer, - server_ssl_cls=SSLWSGIServer) +def run_test_server(host='127.0.0.1', port=0, use_ssl=False): + for item in _run_test_server(address=(host, port), use_ssl=use_ssl, + server_cls=SilentWSGIServer, + server_ssl_cls=SSLWSGIServer): + yield item def make_test_protocol(base): @@ -278,7 +381,7 @@ class TestLoop(base_events.BaseEventLoop): """ def __init__(self, gen=None): - super().__init__() + super(TestLoop, self).__init__() if gen is None: def gen(): @@ -307,7 +410,7 @@ class TestLoop(base_events.BaseEventLoop): self._time += advance def close(self): - super().close() + super(TestLoop, self).close() if self._check_on_close: try: self._gen.send(0) @@ -328,11 +431,11 @@ class TestLoop(base_events.BaseEventLoop): return False def assert_reader(self, fd, callback, *args): - assert fd in self.readers, 'fd {} is not registered'.format(fd) + assert fd in self.readers, 'fd {0} is not registered'.format(fd) handle = self.readers[fd] - assert handle._callback == callback, '{!r} != {!r}'.format( + assert handle._callback == callback, '{0!r} != {1!r}'.format( handle._callback, callback) - assert handle._args == args, '{!r} != {!r}'.format( + assert handle._args == args, '{0!r} != {1!r}'.format( handle._args, args) def add_writer(self, fd, callback, *args): @@ -347,11 +450,11 @@ class TestLoop(base_events.BaseEventLoop): return False def assert_writer(self, fd, callback, *args): - assert fd in self.writers, 'fd {} is not registered'.format(fd) + assert fd in self.writers, 'fd {0} is not registered'.format(fd) handle = self.writers[fd] - assert handle._callback == callback, '{!r} != {!r}'.format( + assert handle._callback == callback, '{0!r} != {1!r}'.format( handle._callback, callback) - assert handle._args == args, '{!r} != {!r}'.format( + assert handle._args == args, '{0!r} != {1!r}'.format( handle._args, args) def reset_counters(self): @@ -359,7 +462,7 @@ class TestLoop(base_events.BaseEventLoop): self.remove_writer_count = collections.defaultdict(int) def _run_once(self): - super()._run_once() + super(TestLoop, self)._run_once() for when in self._timers: advance = self._gen.send(when) self.advance_time(advance) @@ -367,7 +470,7 @@ class TestLoop(base_events.BaseEventLoop): def call_at(self, when, callback, *args): self._timers.append(when) - return super().call_at(when, callback, *args) + return super(TestLoop, self).call_at(when, callback, *args) def _process_events(self, event_list): return @@ -400,8 +503,8 @@ def get_function_source(func): return source -class TestCase(unittest.TestCase): - def set_event_loop(self, loop, *, cleanup=True): +class TestCase(_TestCase): + def set_event_loop(self, loop, cleanup=True): assert loop is not None # ensure that the event loop is passed explicitly in asyncio events.set_event_loop(None) @@ -420,6 +523,48 @@ class TestCase(unittest.TestCase): # in an except block of a generator self.assertEqual(sys.exc_info(), (None, None, None)) + if not hasattr(_TestCase, 'assertRaisesRegex'): + def assertRaisesRegex(self, expected_exception, expected_regex, + callable_obj=None, *args, **kwargs): + """Asserts that the message in a raised exception matches a regex. + + Args: + expected_exception: Exception class expected to be raised. + expected_regex: Regex (re pattern object or string) expected + to be found in error message. + callable_obj: Function to be called. + msg: Optional message used in case of failure. Can only be used + when assertRaisesRegex is used as a context manager. + args: Extra args. + kwargs: Extra kwargs. + """ + context = _AssertRaisesContext(expected_exception, self, callable_obj, + expected_regex) + + return context.handle('assertRaisesRegex', callable_obj, args, kwargs) + + if not hasattr(_TestCase, 'assertRegex'): + def assertRegex(self, text, expected_regex, msg=None): + """Fail the test unless the text matches the regular expression.""" + if isinstance(expected_regex, (str, bytes)): + assert expected_regex, "expected_regex must not be empty." + expected_regex = re.compile(expected_regex) + if not expected_regex.search(text): + msg = msg or "Regex didn't match" + msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text) + raise self.failureException(msg) + + def check_soure_traceback(self, source_traceback, lineno_delta): + frame = sys._getframe(1) + filename = frame.f_code.co_filename + lineno = frame.f_lineno + lineno_delta + name = frame.f_code.co_name + self.assertIsInstance(source_traceback, list) + self.assertEqual(source_traceback[-1][:3], + (filename, + lineno, + name)) + @contextlib.contextmanager def disable_logger(): diff --git a/trollius/transports.py b/trollius/transports.py index 22df3c7..5bdbdaf 100644 --- a/trollius/transports.py +++ b/trollius/transports.py @@ -1,6 +1,7 @@ """Abstract Transport class.""" import sys +from .compat import flatten_bytes _PY34 = sys.version_info >= (3, 4) @@ -9,7 +10,7 @@ __all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport', ] -class BaseTransport: +class BaseTransport(object): """Base class for transports.""" def __init__(self, extra=None): @@ -94,12 +95,8 @@ class WriteTransport(BaseTransport): The default implementation concatenates the arguments and calls write() on the result. """ - if not _PY34: - # In Python 3.3, bytes.join() doesn't handle memoryview. - list_of_data = ( - bytes(data) if isinstance(data, memoryview) else data - for data in list_of_data) - self.write(b''.join(list_of_data)) + data = map(flatten_bytes, list_of_data) + self.write(b''.join(data)) def write_eof(self): """Close the write end after flushing buffered data. @@ -230,7 +227,7 @@ class _FlowControlMixin(Transport): override set_write_buffer_limits() (e.g. to specify different defaults). - The subclass constructor must call super().__init__(extra). This + The subclass constructor must call super(Class, self).__init__(extra). This will call set_write_buffer_limits(). The user may call set_write_buffer_limits() and @@ -239,7 +236,7 @@ class _FlowControlMixin(Transport): """ def __init__(self, extra=None, loop=None): - super().__init__(extra) + super(_FlowControlMixin, self).__init__(extra) assert loop is not None self._loop = loop self._protocol_paused = False diff --git a/trollius/unix_events.py b/trollius/unix_events.py index 75e7c9c..fcccaaa 100644 --- a/trollius/unix_events.py +++ b/trollius/unix_events.py @@ -1,4 +1,5 @@ """Selector event loop for Unix with signal handling.""" +from __future__ import absolute_import import errno import os @@ -13,6 +14,7 @@ import warnings from . import base_events from . import base_subprocess +from . import compat from . import constants from . import coroutines from . import events @@ -20,8 +22,13 @@ from . import futures from . import selector_events from . import selectors from . import transports -from .coroutines import coroutine +from .compat import flatten_bytes +from .coroutines import coroutine, From, Return from .log import logger +from .py33_exceptions import ( + reraise, wrap_error, + BlockingIOError, BrokenPipeError, ConnectionResetError, + InterruptedError, ChildProcessError) __all__ = ['SelectorEventLoop', @@ -33,9 +40,10 @@ if sys.platform == 'win32': # pragma: no cover raise ImportError('Signals are not really supported on Windows') -def _sighandler_noop(signum, frame): - """Dummy signal handler.""" - pass +if compat.PY33: + def _sighandler_noop(signum, frame): + """Dummy signal handler.""" + pass class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): @@ -45,23 +53,27 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): """ def __init__(self, selector=None): - super().__init__(selector) + super(_UnixSelectorEventLoop, self).__init__(selector) self._signal_handlers = {} def _socketpair(self): return socket.socketpair() def close(self): - super().close() + super(_UnixSelectorEventLoop, self).close() for sig in list(self._signal_handlers): self.remove_signal_handler(sig) - def _process_self_data(self, data): - for signum in data: - if not signum: - # ignore null bytes written by _write_to_self() - continue - self._handle_signal(signum) + # On Python <= 3.2, the C signal handler of Python writes a null byte into + # the wakeup file descriptor. We cannot retrieve the signal numbers from + # the file descriptor. + if compat.PY33: + def _process_self_data(self, data): + for signum in data: + if not signum: + # ignore null bytes written by _write_to_self() + continue + self._handle_signal(signum) def add_signal_handler(self, sig, callback, *args): """Add a handler for a signal. UNIX only. @@ -88,14 +100,30 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): self._signal_handlers[sig] = handle try: - # Register a dummy signal handler to ask Python to write the signal - # number in the wakup file descriptor. _process_self_data() will - # read signal numbers from this file descriptor to handle signals. - signal.signal(sig, _sighandler_noop) + if compat.PY33: + # On Python 3.3 and newer, the C signal handler writes the + # signal number into the wakeup file descriptor and then calls + # Py_AddPendingCall() to schedule the Python signal handler. + # + # Register a dummy signal handler to ask Python to write the + # signal number into the wakup file descriptor. + # _process_self_data() will read signal numbers from this file + # descriptor to handle signals. + signal.signal(sig, _sighandler_noop) + else: + # On Python 3.2 and older, the C signal handler first calls + # Py_AddPendingCall() to schedule the Python signal handler, + # and then write a null byte into the wakeup file descriptor. + signal.signal(sig, self._handle_signal) # Set SA_RESTART to limit EINTR occurrences. signal.siginterrupt(sig, False) - except OSError as exc: + except (RuntimeError, OSError) as exc: + # On Python 2, signal.signal(signal.SIGKILL, signal.SIG_IGN) raises + # RuntimeError(22, 'Invalid argument'). On Python 3, + # OSError(22, 'Invalid argument') is raised instead. + exc_type, exc_value, tb = sys.exc_info() + del self._signal_handlers[sig] if not self._signal_handlers: try: @@ -103,12 +131,12 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): except (ValueError, OSError) as nexc: logger.info('set_wakeup_fd(-1) failed: %s', nexc) - if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) + if isinstance(exc, RuntimeError) or exc.errno == errno.EINVAL: + raise RuntimeError('sig {0} cannot be caught'.format(sig)) else: - raise + reraise(exc_type, exc_value, tb) - def _handle_signal(self, sig): + def _handle_signal(self, sig, frame=None): """Internal helper that is the actual signal handler.""" handle = self._signal_handlers.get(sig) if handle is None: @@ -138,7 +166,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): signal.signal(sig, handler) except OSError as exc: if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) + raise RuntimeError('sig {0} cannot be caught'.format(sig)) else: raise @@ -157,11 +185,11 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): Raise RuntimeError if there is a problem setting up the handler. """ if not isinstance(sig, int): - raise TypeError('sig must be an int, not {!r}'.format(sig)) + raise TypeError('sig must be an int, not {0!r}'.format(sig)) if not (1 <= sig < signal.NSIG): raise ValueError( - 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + 'sig {0} out of range(1, {1})'.format(sig, signal.NSIG)) def _make_read_pipe_transport(self, pipe, protocol, waiter=None, extra=None): @@ -185,7 +213,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): watcher.add_child_handler(transp.get_pid(), self._child_watcher_callback, transp) try: - yield from waiter + yield From(waiter) except Exception as exc: # Workaround CPython bug #23353: using yield/yield-from in an # except block of a generator doesn't clear properly @@ -196,16 +224,16 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): if err is not None: transp.close() - yield from transp._wait() + yield From(transp._wait()) raise err - return transp + raise Return(transp) def _child_watcher_callback(self, pid, returncode, transp): self.call_soon_threadsafe(transp._process_exited, returncode) @coroutine - def create_unix_connection(self, protocol_factory, path, *, + def create_unix_connection(self, protocol_factory, path, ssl=None, sock=None, server_hostname=None): assert server_hostname is None or isinstance(server_hostname, str) @@ -225,7 +253,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) try: sock.setblocking(False) - yield from self.sock_connect(sock, path) + yield From(self.sock_connect(sock, path)) except: sock.close() raise @@ -235,12 +263,12 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): raise ValueError('no path and sock were specified') sock.setblocking(False) - transport, protocol = yield from self._create_connection_transport( - sock, protocol_factory, ssl, server_hostname) - return transport, protocol + transport, protocol = yield From(self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname)) + raise Return(transport, protocol) @coroutine - def create_unix_server(self, protocol_factory, path=None, *, + def create_unix_server(self, protocol_factory, path=None, sock=None, backlog=100, ssl=None): if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') @@ -254,13 +282,13 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): try: sock.bind(path) - except OSError as exc: + except socket.error as exc: sock.close() if exc.errno == errno.EADDRINUSE: # Let's improve the error message by adding # with what exact address it occurs. - msg = 'Address {!r} is already in use'.format(path) - raise OSError(errno.EADDRINUSE, msg) from None + msg = 'Address {0!r} is already in use'.format(path) + raise OSError(errno.EADDRINUSE, msg) else: raise except: @@ -273,7 +301,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): if sock.family != socket.AF_UNIX: raise ValueError( - 'A UNIX Domain Socket was expected, got {!r}'.format(sock)) + 'A UNIX Domain Socket was expected, got {0!r}'.format(sock)) server = base_events.Server(self, [sock]) sock.listen(backlog) @@ -283,6 +311,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): if hasattr(os, 'set_blocking'): + # Python 3.5 and newer def _set_nonblocking(fd): os.set_blocking(fd, False) else: @@ -299,7 +328,7 @@ class _UnixReadPipeTransport(transports.ReadTransport): max_size = 256 * 1024 # max bytes we read in one event loop iteration def __init__(self, loop, pipe, protocol, waiter=None, extra=None): - super().__init__(extra) + super(_UnixReadPipeTransport, self).__init__(extra) self._extra['pipe'] = pipe self._loop = loop self._pipe = pipe @@ -341,7 +370,7 @@ class _UnixReadPipeTransport(transports.ReadTransport): def _read_ready(self): try: - data = os.read(self._fileno, self.max_size) + data = wrap_error(os.read, self._fileno, self.max_size) except (BlockingIOError, InterruptedError): pass except OSError as exc: @@ -409,7 +438,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, transports.WriteTransport): def __init__(self, loop, pipe, protocol, waiter=None, extra=None): - super().__init__(extra, loop) + super(_UnixWritePipeTransport, self).__init__(extra, loop) self._extra['pipe'] = pipe self._pipe = pipe self._fileno = pipe.fileno() @@ -475,9 +504,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, self._close() def write(self, data): - assert isinstance(data, (bytes, bytearray, memoryview)), repr(data) - if isinstance(data, bytearray): - data = memoryview(data) + data = flatten_bytes(data) if not data: return @@ -491,7 +518,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, if not self._buffer: # Attempt to send it right away first. try: - n = os.write(self._fileno, data) + n = wrap_error(os.write, self._fileno, data) except (BlockingIOError, InterruptedError): n = 0 except Exception as exc: @@ -511,9 +538,9 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, data = b''.join(self._buffer) assert data, 'Data should not be empty' - self._buffer.clear() + del self._buffer[:] try: - n = os.write(self._fileno, data) + n = wrap_error(os.write, self._fileno, data) except (BlockingIOError, InterruptedError): self._buffer.append(data) except Exception as exc: @@ -582,7 +609,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, self._closing = True if self._buffer: self._loop.remove_writer(self._fileno) - self._buffer.clear() + del self._buffer[:] self._loop.remove_reader(self._fileno) self._loop.call_soon(self._call_connection_lost, exc) @@ -633,11 +660,20 @@ class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport): args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, universal_newlines=False, bufsize=bufsize, **kwargs) if stdin_w is not None: + # Retrieve the file descriptor from stdin_w, stdin_w should not + # "own" the file descriptor anymore: closing stdin_fd file + # descriptor must close immediatly the file stdin.close() - self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize) + if hasattr(stdin_w, 'detach'): + stdin_fd = stdin_w.detach() + self._proc.stdin = os.fdopen(stdin_fd, 'wb', bufsize) + else: + stdin_dup = os.dup(stdin_w.fileno()) + stdin_w.close() + self._proc.stdin = os.fdopen(stdin_dup, 'wb', bufsize) -class AbstractChildWatcher: +class AbstractChildWatcher(object): """Abstract base class for monitoring child processes. Objects derived from this class monitor a collection of subprocesses and @@ -773,12 +809,12 @@ class SafeChildWatcher(BaseChildWatcher): """ def __init__(self): - super().__init__() + super(SafeChildWatcher, self).__init__() self._callbacks = {} def close(self): self._callbacks.clear() - super().close() + super(SafeChildWatcher, self).close() def __enter__(self): return self @@ -850,7 +886,7 @@ class FastChildWatcher(BaseChildWatcher): (O(1) each time a child terminates). """ def __init__(self): - super().__init__() + super(FastChildWatcher, self).__init__() self._callbacks = {} self._lock = threading.Lock() self._zombies = {} @@ -859,7 +895,7 @@ class FastChildWatcher(BaseChildWatcher): def close(self): self._callbacks.clear() self._zombies.clear() - super().close() + super(FastChildWatcher, self).close() def __enter__(self): with self._lock: @@ -906,7 +942,7 @@ class FastChildWatcher(BaseChildWatcher): # long as we're able to reap a child. while True: try: - pid, status = os.waitpid(-1, os.WNOHANG) + pid, status = wrap_error(os.waitpid, -1, os.WNOHANG) except ChildProcessError: # No more child processes exist. return @@ -949,7 +985,7 @@ class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): _loop_factory = _UnixSelectorEventLoop def __init__(self): - super().__init__() + super(_UnixDefaultEventLoopPolicy, self).__init__() self._watcher = None def _init_watcher(self): @@ -968,7 +1004,7 @@ class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): the child watcher. """ - super().set_event_loop(loop) + super(_UnixDefaultEventLoopPolicy, self).set_event_loop(loop) if self._watcher is not None and \ isinstance(threading.current_thread(), threading._MainThread): diff --git a/trollius/windows_events.py b/trollius/windows_events.py index 922594f..7f7764e 100644 --- a/trollius/windows_events.py +++ b/trollius/windows_events.py @@ -11,12 +11,15 @@ from . import events from . import base_subprocess from . import futures from . import proactor_events +from . import py33_winapi as _winapi from . import selector_events from . import tasks from . import windows_utils from . import _overlapped -from .coroutines import coroutine +from .coroutines import coroutine, From, Return from .log import logger +from .py33_exceptions import (wrap_error, get_error_class, + ConnectionRefusedError, BrokenPipeError) __all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', @@ -42,14 +45,14 @@ class _OverlappedFuture(futures.Future): Cancelling it will immediately cancel the overlapped operation. """ - def __init__(self, ov, *, loop=None): - super().__init__(loop=loop) + def __init__(self, ov, loop=None): + super(_OverlappedFuture, self).__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] self._ov = ov def _repr_info(self): - info = super()._repr_info() + info = super(_OverlappedFuture, self)._repr_info() if self._ov is not None: state = 'pending' if self._ov.pending else 'completed' info.insert(1, 'overlapped=<%s, %#x>' % (state, self._ov.address)) @@ -73,22 +76,22 @@ class _OverlappedFuture(futures.Future): def cancel(self): self._cancel_overlapped() - return super().cancel() + return super(_OverlappedFuture, self).cancel() def set_exception(self, exception): - super().set_exception(exception) + super(_OverlappedFuture, self).set_exception(exception) self._cancel_overlapped() def set_result(self, result): - super().set_result(result) + super(_OverlappedFuture, self).set_result(result) self._ov = None class _BaseWaitHandleFuture(futures.Future): """Subclass of Future which represents a wait handle.""" - def __init__(self, ov, handle, wait_handle, *, loop=None): - super().__init__(loop=loop) + def __init__(self, ov, handle, wait_handle, loop=None): + super(_BaseWaitHandleFuture, self).__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] # Keep a reference to the Overlapped object to keep it alive until the @@ -107,7 +110,7 @@ class _BaseWaitHandleFuture(futures.Future): _winapi.WAIT_OBJECT_0) def _repr_info(self): - info = super()._repr_info() + info = super(_BaseWaitHandleFuture, self)._repr_info() info.append('handle=%#x' % self._handle) if self._handle is not None: state = 'signaled' if self._poll() else 'waiting' @@ -147,15 +150,15 @@ class _BaseWaitHandleFuture(futures.Future): def cancel(self): self._unregister_wait() - return super().cancel() + return super(_BaseWaitHandleFuture, self).cancel() def set_exception(self, exception): self._unregister_wait() - super().set_exception(exception) + super(_BaseWaitHandleFuture, self).set_exception(exception) def set_result(self, result): self._unregister_wait() - super().set_result(result) + super(_BaseWaitHandleFuture, self).set_result(result) class _WaitCancelFuture(_BaseWaitHandleFuture): @@ -163,8 +166,9 @@ class _WaitCancelFuture(_BaseWaitHandleFuture): _WaitHandleFuture using an event. """ - def __init__(self, ov, event, wait_handle, *, loop=None): - super().__init__(ov, event, wait_handle, loop=loop) + def __init__(self, ov, event, wait_handle, loop=None): + super(_WaitCancelFuture, self).__init__(ov, event, wait_handle, + loop=loop) self._done_callback = None @@ -178,8 +182,9 @@ class _WaitCancelFuture(_BaseWaitHandleFuture): class _WaitHandleFuture(_BaseWaitHandleFuture): - def __init__(self, ov, handle, wait_handle, proactor, *, loop=None): - super().__init__(ov, handle, wait_handle, loop=loop) + def __init__(self, ov, handle, wait_handle, proactor, loop=None): + super(_WaitHandleFuture, self).__init__(ov, handle, wait_handle, + loop=loop) self._proactor = proactor self._unregister_proactor = True self._event = _overlapped.CreateEvent(None, True, False, None) @@ -201,7 +206,7 @@ class _WaitHandleFuture(_BaseWaitHandleFuture): self._proactor._unregister(self._ov) self._proactor = None - super()._unregister_wait_cb(fut) + super(_WaitHandleFuture, self)._unregister_wait_cb(fut) def _unregister_wait(self): if not self._registered: @@ -259,7 +264,7 @@ class PipeServer(object): flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED if first: flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE - h = _winapi.CreateNamedPipe( + h = wrap_error(_winapi.CreateNamedPipe, self._address, flags, _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_WAIT, @@ -301,7 +306,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop): def __init__(self, proactor=None): if proactor is None: proactor = IocpProactor() - super().__init__(proactor) + super(ProactorEventLoop, self).__init__(proactor) def _socketpair(self): return windows_utils.socketpair() @@ -309,11 +314,11 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop): @coroutine def create_pipe_connection(self, protocol_factory, address): f = self._proactor.connect_pipe(address) - pipe = yield from f + pipe = yield From(f) protocol = protocol_factory() trans = self._make_duplex_pipe_transport(pipe, protocol, extra={'addr': address}) - return trans, protocol + raise Return(trans, protocol) @coroutine def start_serving_pipe(self, protocol_factory, address): @@ -372,7 +377,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop): waiter=waiter, extra=extra, **kwargs) try: - yield from waiter + yield From(waiter) except Exception as exc: # Workaround CPython bug #23353: using yield/yield-from in an # except block of a generator doesn't clear properly sys.exc_info() @@ -382,13 +387,13 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop): if err is not None: transp.close() - yield from transp._wait() + yield From(transp._wait()) raise err - return transp + raise Return(transp) -class IocpProactor: +class IocpProactor(object): """Proactor implementation using IOCP.""" def __init__(self, concurrency=0xffffffff): @@ -426,16 +431,16 @@ class IocpProactor: ov = _overlapped.Overlapped(NULL) try: if isinstance(conn, socket.socket): - ov.WSARecv(conn.fileno(), nbytes, flags) + wrap_error(ov.WSARecv, conn.fileno(), nbytes, flags) else: - ov.ReadFile(conn.fileno(), nbytes) + wrap_error(ov.ReadFile, conn.fileno(), nbytes) except BrokenPipeError: return self._result(b'') def finish_recv(trans, key, ov): try: - return ov.getresult() - except OSError as exc: + return wrap_error(ov.getresult) + except WindowsError as exc: if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: raise ConnectionResetError(*exc.args) else: @@ -453,8 +458,8 @@ class IocpProactor: def finish_send(trans, key, ov): try: - return ov.getresult() - except OSError as exc: + return wrap_error(ov.getresult) + except WindowsError as exc: if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: raise ConnectionResetError(*exc.args) else: @@ -469,7 +474,7 @@ class IocpProactor: ov.AcceptEx(listener.fileno(), conn.fileno()) def finish_accept(trans, key, ov): - ov.getresult() + wrap_error(ov.getresult) # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. buf = struct.pack('@P', listener.fileno()) conn.setsockopt(socket.SOL_SOCKET, @@ -481,7 +486,7 @@ class IocpProactor: def accept_coro(future, conn): # Coroutine closing the accept socket if the future is cancelled try: - yield from future + yield From(future) except futures.CancelledError: conn.close() raise @@ -496,7 +501,7 @@ class IocpProactor: # The socket needs to be locally bound before we call ConnectEx(). try: _overlapped.BindLocal(conn.fileno(), conn.family) - except OSError as e: + except WindowsError as e: if e.winerror != errno.WSAEINVAL: raise # Probably already locally bound; check using getsockname(). @@ -506,7 +511,7 @@ class IocpProactor: ov.ConnectEx(conn.fileno(), address) def finish_connect(trans, key, ov): - ov.getresult() + wrap_error(ov.getresult) # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. conn.setsockopt(socket.SOL_SOCKET, _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) @@ -526,7 +531,7 @@ class IocpProactor: return self._result(pipe) def finish_accept_pipe(trans, key, ov): - ov.getresult() + wrap_error(ov.getresult) return pipe return self._register(ov, pipe, finish_accept_pipe) @@ -539,17 +544,17 @@ class IocpProactor: # Call CreateFile() in a loop until it doesn't fail with # ERROR_PIPE_BUSY try: - handle = _overlapped.ConnectPipe(address) + handle = wrap_error(_overlapped.ConnectPipe, address) break - except OSError as exc: + except WindowsError as exc: if exc.winerror != _overlapped.ERROR_PIPE_BUSY: raise # ConnectPipe() failed with ERROR_PIPE_BUSY: retry later delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) - yield from tasks.sleep(delay, loop=self._loop) + yield From(tasks.sleep(delay, loop=self._loop)) - return windows_utils.PipeHandle(handle) + raise Return(windows_utils.PipeHandle(handle)) def wait_for_handle(self, handle, timeout=None): """Wait for a handle. @@ -572,7 +577,7 @@ class IocpProactor: else: # RegisterWaitForSingleObject() has a resolution of 1 millisecond, # round away from zero to wait *at least* timeout seconds. - ms = math.ceil(timeout * 1e3) + ms = int(math.ceil(timeout * 1e3)) # We only create ov so we can use ov.address as a key for the cache. ov = _overlapped.Overlapped(NULL) @@ -660,7 +665,7 @@ class IocpProactor: else: # GetQueuedCompletionStatus() has a resolution of 1 millisecond, # round away from zero to wait *at least* timeout seconds. - ms = math.ceil(timeout * 1e3) + ms = int(math.ceil(timeout * 1e3)) if ms >= INFINITE: raise ValueError("timeout too big") @@ -705,7 +710,7 @@ class IocpProactor: # Remove unregisted futures for ov in self._unregistered: self._cache.pop(ov.address, None) - self._unregistered.clear() + del self._unregistered[:] def _stop_serving(self, obj): # obj is a socket or pipe handle. It will be closed in diff --git a/trollius/windows_utils.py b/trollius/windows_utils.py index 870cd13..b2b9617 100644 --- a/trollius/windows_utils.py +++ b/trollius/windows_utils.py @@ -1,6 +1,7 @@ """ Various Windows specific bits and pieces """ +from __future__ import absolute_import import sys @@ -16,6 +17,9 @@ import subprocess import tempfile import warnings +from . import py33_winapi as _winapi +from .py33_exceptions import wrap_error, BlockingIOError, InterruptedError + __all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] @@ -64,7 +68,7 @@ else: try: csock.setblocking(False) try: - csock.connect((addr, port)) + wrap_error(csock.connect, (addr, port)) except (BlockingIOError, InterruptedError): pass csock.setblocking(True) @@ -80,7 +84,7 @@ else: # Replacement for os.pipe() using handles instead of fds -def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): +def pipe(duplex=False, overlapped=(True, True), bufsize=BUFSIZE): """Like os.pipe() but with overlapped support and using handles not fds.""" address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % (os.getpid(), next(_mmap_counter))) @@ -115,7 +119,12 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): flags_and_attribs, _winapi.NULL) ov = _winapi.ConnectNamedPipe(h1, overlapped=True) - ov.GetOverlappedResult(True) + if hasattr(ov, 'GetOverlappedResult'): + # _winapi module of Python 3.3 + ov.GetOverlappedResult(True) + else: + # _overlapped module + wrap_error(ov.getresult, True) return h1, h2 except: if h1 is not None: @@ -128,7 +137,7 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): # Wrapper for a pipe handle -class PipeHandle: +class PipeHandle(object): """Wrapper for an overlapped pipe handle which is vaguely file-object like. The IOCP event loop can use these instead of socket objects. @@ -152,7 +161,7 @@ class PipeHandle: raise ValueError("I/O operatioon on closed pipe") return self._handle - def close(self, *, CloseHandle=_winapi.CloseHandle): + def close(self, CloseHandle=_winapi.CloseHandle): if self._handle is not None: CloseHandle(self._handle) self._handle = None @@ -200,8 +209,11 @@ class Popen(subprocess.Popen): else: stderr_wfd = stderr try: - super().__init__(args, stdin=stdin_rfd, stdout=stdout_wfd, - stderr=stderr_wfd, **kwds) + super(Popen, self).__init__(args, + stdin=stdin_rfd, + stdout=stdout_wfd, + stderr=stderr_wfd, + **kwds) except: for h in (stdin_wh, stdout_rh, stderr_rh): if h is not None: