From c5bfa24e298527a61edb7ece6102466b259aa566 Mon Sep 17 00:00:00 2001 From: Kenneth Giusti Date: Mon, 24 Aug 2015 09:58:17 -0400 Subject: [PATCH] Fix reentrancy issues --- examples/recv.py | 75 ++++++++--- examples/rpc-client.py | 20 ++- examples/rpc-server.py | 42 +++++-- examples/send.py | 45 ++++++- examples/server.py | 67 ++++++---- examples/utils.py | 20 ++- pyngus/connection.py | 223 ++++++++++++++++++++------------- pyngus/link.py | 108 ++++++++++++---- pyngus/sockets.py | 85 +++++++------ setup.py | 2 +- tests/unit_tests/common.py | 51 +++++--- tests/unit_tests/connection.py | 8 +- tests/unit_tests/link.py | 60 +++++++++ 13 files changed, 562 insertions(+), 244 deletions(-) diff --git a/examples/recv.py b/examples/recv.py index b67686d..c263462 100755 --- a/examples/recv.py +++ b/examples/recv.py @@ -33,6 +33,42 @@ LOG = logging.getLogger() LOG.addHandler(logging.StreamHandler()) +class ConnectionEventHandler(pyngus.ConnectionEventHandler): + def connection_failed(self, connection, error): + """Connection has failed in some way.""" + LOG.warn("Connection failed: %s", error) + connection.close() + + def connection_remote_closed(self, connection, pn_condition): + """Peer has closed its end of the connection.""" + LOG.debug("connection_remote_closed condition=%s", pn_condition) + connection.close() + + +class ReceiverEventHandler(pyngus.ReceiverEventHandler): + def __init__(self): + self.done = False + self.message = None + self.handle = None + + def receiver_remote_closed(self, receiver_link, pn_condition): + """Peer has closed its end of the link.""" + LOG.debug("receiver_remote_closed condition=%s", pn_condition) + receiver_link.close() + self.done = True + + def receiver_failed(self, receiver_link, error): + """Protocol error occurred.""" + LOG.warn("receiver_failed error=%s", error) + receiver_link.close() + self.done = True + + def message_received(self, receiver, message, handle): + self.done = True + self.message = message + self.handle = handle + + def main(argv=None): _usage = """Usage: %prog [options]""" @@ -53,6 +89,12 @@ def main(argv=None): help="enable protocol tracing") parser.add_option("--ca", help="Certificate Authority PEM file") + parser.add_option("--username", type="string", + help="User Id for authentication") + parser.add_option("--password", type="string", + help="User password for authentication") + parser.add_option("--sasl-mechs", type="string", + help="The list of acceptable SASL mechs") opts, extra = parser.parse_args(args=argv) if opts.debug: @@ -64,35 +106,28 @@ def main(argv=None): # container = pyngus.Container(uuid.uuid4().hex) conn_properties = {'hostname': host, - 'x-server': False, - 'x-username': 'guest', - 'x-password': 'guest', - 'x-sasl-mechs': "ANONYMOUS PLAIN"} + 'x-server': False} if opts.trace: conn_properties["x-trace-protocol"] = True if opts.ca: conn_properties["x-ssl-ca-file"] = opts.ca if opts.idle_timeout: conn_properties["idle-time-out"] = opts.idle_timeout + if opts.username: + conn_properties['x-username'] = opts.username + if opts.password: + conn_properties['x-password'] = opts.password + if opts.sasl_mechs: + conn_properties['x-sasl-mechs'] = opts.sasl_mechs + c_handler = pyngus.ConnectionEventHandler() connection = container.create_connection("receiver", - None, # no events + c_handler, conn_properties) connection.open() - class ReceiveCallback(pyngus.ReceiverEventHandler): - def __init__(self): - self.done = False - self.message = None - self.handle = None - - def message_received(self, receiver, message, handle): - self.done = True - self.message = message - self.handle = handle - target_address = opts.target_addr or uuid.uuid4().hex - cb = ReceiveCallback() + cb = ReceiverEventHandler() receiver = connection.create_receiver(target_address, opts.source_addr, cb) @@ -104,8 +139,10 @@ def main(argv=None): process_connection(connection, my_socket) if cb.done: - print("Receive done, message=%s" % str(cb.message)) - receiver.message_accepted(cb.handle) + print("Receive done, message=%s" % str(cb.message) if cb.message + else "ERROR: no message received") + if cb.handle: + receiver.message_accepted(cb.handle) else: print("Receive failed due to connection failure!") diff --git a/examples/rpc-client.py b/examples/rpc-client.py index 65dec96..c594299 100755 --- a/examples/rpc-client.py +++ b/examples/rpc-client.py @@ -93,12 +93,23 @@ class MyConnection(pyngus.ConnectionEventHandler): LOG.debug("select() returned") if readable: - pyngus.read_socket_input(self.connection, - self.socket) + try: + pyngus.read_socket_input(self.connection, + self.socket) + except Exception as e: + LOG.error("Exception on socket read: %s", str(e)) + self.connection.close_input() + self.connection.close() + self.connection.process(time.time()) if writable: - pyngus.write_socket_output(self.connection, - self.socket) + try: + pyngus.write_socket_output(self.connection, + self.socket) + except Exception as e: + LOG.error("Exception on socket write: %s", str(e)) + self.connection.close_output() + self.connection.close() def close(self, error=None): self.connection.close(error) @@ -131,6 +142,7 @@ class MyConnection(pyngus.ConnectionEventHandler): def connection_remote_closed(self, connection, reason): LOG.debug("Connection remote closed callback") + connection.close() def connection_closed(self, connection): LOG.debug("Connection closed callback") diff --git a/examples/rpc-server.py b/examples/rpc-server.py index 3d7951d..5dad460 100755 --- a/examples/rpc-server.py +++ b/examples/rpc-server.py @@ -53,6 +53,9 @@ LOG.addHandler(logging.StreamHandler()) sender_links = {} receiver_links = {} +# links that have closed and need to be destroyed: +dead_links = set() + # Map reply-to address to the proper sending link (indexed by address) reply_senders = {} @@ -92,10 +95,8 @@ class SocketConnection(pyngus.ConnectionEventHandler): pyngus.read_socket_input(self.connection, self.socket) except Exception as e: LOG.error("Exception on socket read: %s", str(e)) - # may be redundant if closed cleanly: - self.connection_closed(self.connection) - return - + self.connection.close_input() + self.connection.close() self.connection.process(time.time()) def send_output(self): @@ -105,10 +106,8 @@ class SocketConnection(pyngus.ConnectionEventHandler): self.socket) except Exception as e: LOG.error("Exception on socket write: %s", str(e)) - # may be redundant if closed cleanly: - self.connection_closed(self.connection) - return - + self.connection.close_output() + self.connection.close() self.connection.process(time.time()) # ConnectionEventHandler callbacks: @@ -198,6 +197,12 @@ class MySenderLink(pyngus.SenderEventHandler): properties) self.sender_link.open() + @property + def closed(self): + if self.sender_link: + return self.sender_link.closed + return True + # SenderEventHandler callbacks: def sender_active(self, sender_link): @@ -211,12 +216,14 @@ class MySenderLink(pyngus.SenderEventHandler): LOG.debug("sender closed callback") global sender_links global reply_senders + global dead_links if self._ident in sender_links: del sender_links[self._ident] if self._source_address in reply_senders: del reply_senders[self._source_address] - self.sender_link.destroy() + + dead_links.add(self.sender_link) self.sender_link = None # 'message sent' callback: @@ -238,6 +245,12 @@ class MyReceiverLink(pyngus.ReceiverEventHandler): properties) self._link.open() + @property + def closed(self): + if self._link: + return self._link.closed + return True + # ReceiverEventHandler callbacks: def receiver_active(self, receiver_link): LOG.debug("receiver active callback") @@ -250,10 +263,12 @@ class MyReceiverLink(pyngus.ReceiverEventHandler): def receiver_closed(self, receiver_link): LOG.debug("receiver closed callback") global receiver_links + global dead_links if self._ident in receiver_links: del receiver_links[self._ident] - self._link.destroy() + + dead_links.add(self._link) self._link = None def message_received(self, receiver_link, message, handle): @@ -353,6 +368,7 @@ def main(argv=None): # container = pyngus.Container("example RPC service") global socket_connections + global dead_links while True: @@ -428,7 +444,11 @@ def main(argv=None): w.send_output() worked.append(w) - # nuke any completed connections: + # first, free any closed links: + while dead_links: + dead_links.pop().destroy() + + # then nuke any completed connections: closed = False while worked: sc = worked.pop() diff --git a/examples/send.py b/examples/send.py index e656280..4ffeade 100755 --- a/examples/send.py +++ b/examples/send.py @@ -35,6 +35,29 @@ LOG = logging.getLogger() LOG.addHandler(logging.StreamHandler()) +class ConnectionEventHandler(pyngus.ConnectionEventHandler): + def connection_failed(self, connection, error): + """Connection's transport has failed in some way.""" + LOG.warn("Connection failed: %s", error) + connection.close() + + def connection_remote_closed(self, connection, pn_condition): + """Peer has closed its end of the connection.""" + LOG.debug("connection_remote_closed condition=%s", pn_condition) + connection.close() + + +class SenderEventHandler(pyngus.SenderEventHandler): + def sender_remote_closed(self, sender_link, pn_condition): + LOG.debug("Sender peer_closed condition=%s", pn_condition) + sender_link.close() + + def sender_failed(self, sender_link, error): + """Protocol error occurred.""" + LOG.debug("Sender failed error=%s", error) + sender_link.close() + + def main(argv=None): _usage = """Usage: %prog [options] [message content string]""" @@ -55,6 +78,12 @@ def main(argv=None): help="enable protocol tracing") parser.add_option("--ca", help="Certificate Authority PEM file") + parser.add_option("--username", type="string", + help="User Id for authentication") + parser.add_option("--password", type="string", + help="User password for authentication") + parser.add_option("--sasl-mechs", type="string", + help="The list of acceptable SASL mechs") opts, payload = parser.parse_args(args=argv) if not payload: @@ -69,23 +98,31 @@ def main(argv=None): # container = pyngus.Container(uuid.uuid4().hex) conn_properties = {'hostname': host, - 'x-server': False, - 'x-sasl-mechs': "ANONYMOUS PLAIN"} + 'x-server': False} if opts.trace: conn_properties["x-trace-protocol"] = True if opts.ca: conn_properties["x-ssl-ca-file"] = opts.ca if opts.idle_timeout: conn_properties["idle-time-out"] = opts.idle_timeout + if opts.username: + conn_properties['x-username'] = opts.username + if opts.password: + conn_properties['x-password'] = opts.password + if opts.sasl_mechs: + conn_properties['x-sasl-mechs'] = opts.sasl_mechs + c_handler = ConnectionEventHandler() connection = container.create_connection("sender", - None, # no events + c_handler, conn_properties) connection.open() source_address = opts.source_addr or uuid.uuid4().hex + s_handler = SenderEventHandler() sender = connection.create_sender(source_address, - opts.target_addr) + opts.target_addr, + s_handler) sender.open() # Send a single message: diff --git a/examples/server.py b/examples/server.py index 39dd9b7..dbacf8c 100755 --- a/examples/server.py +++ b/examples/server.py @@ -47,13 +47,11 @@ class SocketConnection(pyngus.ConnectionEventHandler): properties) self.connection.user_context = self self.connection.open() - self.closed = False self.sender_links = set() self.receiver_links = set() def destroy(self): - self.closed = True for link in self.sender_links.copy(): link.destroy() for link in self.receiver_links.copy(): @@ -65,6 +63,10 @@ class SocketConnection(pyngus.ConnectionEventHandler): self.socket.close() self.socket = None + @property + def closed(self): + return self.connection == None or self.connection.closed + def fileno(self): """Allows use of a SocketConnection in a select() call.""" return self.socket.fileno() @@ -75,9 +77,8 @@ class SocketConnection(pyngus.ConnectionEventHandler): pyngus.read_socket_input(self.connection, self.socket) except Exception as e: LOG.error("Exception on socket read: %s", str(e)) - # may be redundant if closed cleanly: - self.connection_closed(self.connection) - return + self.connection.close_input() + self.connection.close() self.connection.process(time.time()) def send_output(self): @@ -87,9 +88,8 @@ class SocketConnection(pyngus.ConnectionEventHandler): self.socket) except Exception as e: LOG.error("Exception on socket write: %s", str(e)) - # may be redundant if closed cleanly: - self.connection_closed(self.connection) - return + self.connection.close_output() + self.connection.close() self.connection.process(time.time()) # ConnectionEventHandler callbacks: @@ -102,13 +102,10 @@ class SocketConnection(pyngus.ConnectionEventHandler): def connection_closed(self, connection): LOG.debug("Connection: closed.") - # main loop will destroy - self.closed = True def connection_failed(self, connection, error): LOG.error("Connection: failed! error=%s", str(error)) - # No special recovery - just simulate close completed: - self.connection_closed(connection) + self.connection.close() def sender_requested(self, connection, link_handle, name, requested_source, properties): @@ -152,6 +149,10 @@ class MySenderLink(pyngus.SenderEventHandler): self.sender_link.open() print("New Sender link created, name=%s" % sl.name) + @property + def closed(self): + return self.sender_link.closed + def destroy(self): print("Sender link destroyed, name=%s" % self.sender_link.name) self.socket_conn.sender_links.discard(self) @@ -176,11 +177,6 @@ class MySenderLink(pyngus.SenderEventHandler): LOG.debug("Sender: Remote closed") self.sender_link.close() - def sender_closed(self, sender_link): - LOG.debug("Sender: Closed") - # Done with this sender: - self.destroy() - def credit_granted(self, sender_link): LOG.debug("Sender: credit granted") # Send a single message: @@ -208,6 +204,10 @@ class MyReceiverLink(pyngus.ReceiverEventHandler): self.receiver_link.add_capacity(1) print("New Receiver link created, name=%s" % rl.name) + @property + def closed(self): + return self.receiver_link.closed + def destroy(self): print("Receiver link destroyed, name=%s" % self.receiver_link.name) self.socket_conn.receiver_links.discard(self) @@ -224,11 +224,6 @@ class MyReceiverLink(pyngus.ReceiverEventHandler): LOG.debug("Receiver: Remote closed") self.receiver_link.close() - def receiver_closed(self, receiver_link): - LOG.debug("Receiver: Closed") - # Done with this Receiver: - self.destroy() - def message_received(self, receiver_link, message, handle): self.receiver_link.message_accepted(handle) print("Message received on Receiver link %s, message=%s" @@ -256,6 +251,14 @@ def main(argv=None): help="PEM File containing the server's private key") parser.add_option("--keypass", help="Password used to decrypt key file") + parser.add_option("--require-auth", action="store_true", + help="Require clients to authenticate") + parser.add_option("--sasl-mechs", type="string", + help="The list of acceptable SASL mechs") + parser.add_option("--sasl-cfg-name", type="string", + help="name of SASL config file (no suffix)") + parser.add_option("--sasl-cfg-dir", type="string", + help="Path to the SASL config file") opts, arguments = parser.parse_args(args=argv) if opts.debug: @@ -300,9 +303,15 @@ def main(argv=None): client_socket, client_address = my_socket.accept() # name = uuid.uuid4().hex name = str(client_address) - conn_properties = {'x-server': True, - 'x-require-auth': False, - 'x-sasl-mechs': "ANONYMOUS"} + conn_properties = {'x-server': True} + if opts.require_auth: + conn_properties['x-require-auth'] = True + if opts.sasl_mechs: + conn_properties['x-sasl-mechs'] = opts.sasl_mechs + if opts.sasl_cfg_name: + conn_properties['x-sasl-config-name'] = opts.sasl_cfg_name + if opts.sasl_cfg_dir: + conn_properties['x-sasl-config-dir'] = opts.sasl_cfg_dir if opts.idle_timeout: conn_properties["idle-time-out"] = opts.idle_timeout if opts.trace: @@ -337,14 +346,20 @@ def main(argv=None): w.send_output() worked.add(w) - # nuke any completed connections: closed = False while worked: sc = worked.pop() + # nuke any completed connections: if sc.closed: socket_connections.discard(sc) sc.destroy() closed = True + else: + # can free any closed links now (optional): + for link in sc.sender_links | sc.receiver_links: + if link.closed: + link.destroy() + if closed: LOG.debug("%d active connections present", len(socket_connections)) diff --git a/examples/utils.py b/examples/utils.py index c811729..aec9eba 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -19,6 +19,7 @@ """Utilities used by the Examples""" import errno +import logging import re import socket import select @@ -26,6 +27,7 @@ import time import pyngus +LOG = logging.getLogger() def get_host_port(server_address): """Parse the hostname and port out of the server_address.""" @@ -103,10 +105,24 @@ def process_connection(connection, my_socket): [], timeout) if readable: - pyngus.read_socket_input(connection, my_socket) + try: + pyngus.read_socket_input(connection, my_socket) + except Exception as e: + # treat any socket error as + LOG.error("Socket error on read: %s", str(e)) + connection.close_input() + # make an attempt to cleanly close + connection.close() + connection.process(time.time()) if writable: - pyngus.write_socket_output(connection, my_socket) + try: + pyngus.write_socket_output(connection, my_socket) + except Exception as e: + LOG.error("Socket error on write %s", str(e)) + connection.close_output() + # this may not help, but it won't hurt: + connection.close() return True # Map the send callback status to a string diff --git a/pyngus/connection.py b/pyngus/connection.py index e070bb1..7d79c23 100644 --- a/pyngus/connection.py +++ b/pyngus/connection.py @@ -35,6 +35,24 @@ LOG = logging.getLogger(__name__) _PROTON_VERSION = (int(getattr(proton, "VERSION_MAJOR", 0)), int(getattr(proton, "VERSION_MINOR", 0))) +class _CallbackLock(object): + """A utility class for detecting when a callback invokes a non-reentrant + Pyngus method. + """ + def __init__(self): + super(_CallbackLock, self).__init__() + self.in_callback = 0 + + def __enter__(self): + self.in_callback += 1 + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.in_callback -= 1 + # if a call is made to a non-reentrant method while this context is + # held, then the method will raise a RuntimeError(). Return false to + # propagate the exception to the caller + return False class ConnectionEventHandler(object): """An implementation of an AMQP 1.0 Connection.""" @@ -89,6 +107,17 @@ class Connection(Endpoint): 'x-sasl-mechs', 'x-sasl-config-dir', 'x-sasl-config-name']) + def _not_reentrant(func): + """Decorator that prevents callbacks from calling into methods that are + not reentrant + """ + def wrap(self, *args, **kws): + if self._callback_lock.in_callback: + m = "Connection %s cannot be invoked from a callback!" % func + raise RuntimeError(m) + return func(self, *args, **kws) + return wrap + def __init__(self, container, name, event_handler=None, properties=None): """Create a new connection from the Container @@ -213,8 +242,8 @@ class Connection(Endpoint): self._error = None self._next_deadline = 0 self._user_context = None - self._in_process = False self._remote_session_id = 0 + self._callback_lock = _CallbackLock() self._pn_sasl = None self._sasl_done = False @@ -242,19 +271,21 @@ class Connection(Endpoint): else: # new Proton SASL configuration if 'x-require-auth' in self._properties: - self._init_sasl() ra = self._properties['x-require-auth'] self._pn_transport.require_auth(ra) if 'x-username' in self._properties: - self._init_sasl() + # maintain old behavior: allow PLAIN and ANONYMOUS + # authentication. Override this using x-sasl-mechs below: + self.pn_sasl.allow_insecure_mechs = True self._pn_connection.user = self._properties['x-username'] if 'x-password' in self._properties: - self._init_sasl() self._pn_connection.password = \ self._properties['x-password'] if 'x-sasl-mechs' in self._properties: - self.pn_sasl.allowed_mechs( - self._properties['x-sasl-mechs']) + mechs = self._properties['x-sasl-mechs'].upper() + self.pn_sasl.allowed_mechs(mechs) + if 'PLAIN' not in mechs and 'ANONYMOUS' not in mechs: + self.pn_sasl.allow_insecure_mechs = False if 'x-sasl-config-dir' in self._properties: self.pn_sasl.config_path( self._properties['x-sasl-config-dir']) @@ -309,13 +340,10 @@ class Connection(Endpoint): return self._pn_connection.remote_properties return None - def _init_sasl(self): - if not self._pn_sasl: - self._pn_sasl = self._pn_transport.sasl() - @property def pn_sasl(self): - self._init_sasl() + if not self._pn_sasl: + self._pn_sasl = self._pn_transport.sasl() return self._pn_sasl def pn_ssl(self): @@ -359,8 +387,13 @@ class Connection(Endpoint): """Return True if the Connection has finished closing.""" return (self._write_done and self._read_done) + @_not_reentrant def destroy(self): - self._error = "Destroyed" + # if a connection is destroyed without flushing pending output, + # the remote will see an unclean shutdown (framing error) + if self.has_output > 0: + LOG.debug("Connection with buffered output destroyed") + self._error = "Destroyed by the application" self._handler = None self._sender_links.clear() self._receiver_links.clear() @@ -382,79 +415,80 @@ class Connection(Endpoint): _CLOSED = (proton.Endpoint.LOCAL_CLOSED | proton.Endpoint.REMOTE_CLOSED) _ACTIVE = (proton.Endpoint.LOCAL_ACTIVE | proton.Endpoint.REMOTE_ACTIVE) + @_not_reentrant def process(self, now): """Perform connection state processing.""" - if self._in_process: - raise RuntimeError("Connection.process() is not re-entrant!") - self._in_process = True - try: - # if the connection has hit an unrecoverable error, - # nag the application until connection is destroyed - if self._error: - if self._handler: - self._handler.connection_failed(self, self._error) - # nag application until connection is destroyed - self._next_deadline = now - return now + if self._pn_connection is None: + LOG.error("Connection.process() called on destroyed connection!") + return 0 - # do nothing until the connection has been opened - if self._pn_connection.state & proton.Endpoint.LOCAL_UNINIT: - return 0 + # do nothing until the connection has been opened + if self._pn_connection.state & proton.Endpoint.LOCAL_UNINIT: + return 0 + if self._pn_sasl and not self._sasl_done: + # wait until SASL has authenticated if (_PROTON_VERSION < (0, 10)): - # wait until SASL has authenticated - if self._pn_sasl and not self._sasl_done: - if self._pn_sasl.state not in (proton.SASL.STATE_PASS, - proton.SASL.STATE_FAIL): - LOG.debug("SASL in progress. State=%s", - str(self._pn_sasl.state)) - if self._handler: - self._handler.sasl_step(self, self._pn_sasl) - return self._next_deadline - - self._sasl_done = True + if self._pn_sasl.state not in (proton.SASL.STATE_PASS, + proton.SASL.STATE_FAIL): + LOG.debug("SASL in progress. State=%s", + str(self._pn_sasl.state)) if self._handler: + with self._callback_lock: + self._handler.sasl_step(self, self._pn_sasl) + return self._next_deadline + + self._sasl_done = True + if self._handler: + with self._callback_lock: self._handler.sasl_done(self, self._pn_sasl, self._pn_sasl.outcome) - - # process timer events: - timer_deadline = self._expire_timers(now) - transport_deadline = self._pn_transport.tick(now) - if timer_deadline and transport_deadline: - self._next_deadline = min(timer_deadline, transport_deadline) else: - self._next_deadline = timer_deadline or transport_deadline + if self._pn_sasl.outcome is not None: + self._sasl_done = True + if self._handler: + with self._callback_lock: + self._handler.sasl_done(self, self._pn_sasl, + self._pn_sasl.outcome) - # process events from proton: + # process timer events: + timer_deadline = self._expire_timers(now) + transport_deadline = self._pn_transport.tick(now) + if timer_deadline and transport_deadline: + self._next_deadline = min(timer_deadline, transport_deadline) + else: + self._next_deadline = timer_deadline or transport_deadline + + # process events from proton: + pn_event = self._pn_collector.peek() + while pn_event: + LOG.debug("pn_event: %s received", pn_event.type) + if self._handle_proton_event(pn_event): + pass + elif _SessionProxy._handle_proton_event(pn_event, self): + pass + elif _Link._handle_proton_event(pn_event, self): + pass + self._pn_collector.pop() pn_event = self._pn_collector.peek() - while pn_event: - if self._handle_proton_event(pn_event): - pass - elif _SessionProxy._handle_proton_event(pn_event, self): - pass - elif _Link._handle_proton_event(pn_event, self): - pass - self._pn_collector.pop() - pn_event = self._pn_collector.peek() - # re-check for connection failure after processing all pending - # engine events: - if self._error: - if self._handler: + # check for connection failure after processing all pending + # engine events: + if self._error: + if self._handler: + # nag application until connection is destroyed + self._next_deadline = now + with self._callback_lock: self._handler.connection_failed(self, self._error) - # nag application until connection is destroyed - self._next_deadline = now - elif (self._endpoint_state == self._CLOSED and - self._read_done and self._write_done): - # invoke closed callback after endpoint has fully closed and - # all pending I/O has completed: - if self._handler: + elif (self._endpoint_state == self._CLOSED and + self._read_done and self._write_done): + # invoke closed callback after endpoint has fully closed and + # all pending I/O has completed: + if self._handler: + with self._callback_lock: self._handler.connection_closed(self) - return self._next_deadline - - finally: - self._in_process = False + return self._next_deadline @property def next_tick(self): @@ -470,14 +504,17 @@ class Connection(Endpoint): @property def needs_input(self): if self._read_done: + LOG.debug("needs_input EOS") return self.EOS try: - # TODO(grs): can this actually throw? capacity = self._pn_transport.capacity() except Exception as e: - return self._connection_failed(str(e)) + self._read_done = True + self._connection_failed(str(e)) + return self.EOS if capacity >= 0: return capacity + LOG.debug("needs_input read done") self._read_done = True return self.EOS @@ -486,10 +523,14 @@ class Connection(Endpoint): if c <= 0: return c try: + LOG.debug("pushing %s bytes to transport:", c) rc = self._pn_transport.push(in_data[:c]) except Exception as e: - return self._connection_failed(str(e)) + self._read_done = True + self._connection_failed(str(e)) + return self.EOS if rc: # error? + LOG.debug("process_input read done") self._read_done = True return self.EOS # hack: check if this was the last input needed by the connection. @@ -504,18 +545,23 @@ class Connection(Endpoint): self._pn_transport.close_tail() except Exception as e: self._connection_failed(str(e)) + LOG.debug("close_input read done") self._read_done = True @property def has_output(self): if self._write_done: + LOG.debug("has output EOS") return self.EOS try: pending = self._pn_transport.pending() except Exception as e: - return self._connection_failed(str(e)) + self._write_done = True + self._connection_failed(str(e)) + return self.EOS if pending >= 0: return pending + LOG.debug("has output write_done") self._write_done = True return self.EOS @@ -526,6 +572,7 @@ class Connection(Endpoint): if c <= 0: return None try: + LOG.debug("Getting %s bytes output from transport", c) buf = self._pn_transport.peek(c) except Exception as e: self._connection_failed(str(e)) @@ -534,9 +581,11 @@ class Connection(Endpoint): def output_written(self, count): try: + LOG.debug("Popping %s bytes output from transport", count) self._pn_transport.pop(count) except Exception as e: - return self._connection_failed(str(e)) + self._write_done = True + self._connection_failed(str(e)) # hack: check if this was the last output from the connection. If so, # this will set the _write_done flag and the 'connection closed' # callback can be issued on the next call to process() @@ -548,6 +597,7 @@ class Connection(Endpoint): self._pn_transport.close_head() except Exception as e: self._connection_failed(str(e)) + LOG.debug("close output write done") self._write_done = True def create_sender(self, source_address, target_address=None, @@ -584,6 +634,9 @@ class Connection(Endpoint): if not link: raise Exception("Invalid link_handle: %s" % link_handle) link.reject(pn_condition) + # note: normally, link.destroy() cannot be called from a callback, + # but this link was never made available to the application so this + # link is only referenced by the connection link.destroy() def create_receiver(self, target_address, source_address=None, @@ -619,6 +672,9 @@ class Connection(Endpoint): if not link: raise Exception("Invalid link_handle: %s" % link_handle) link.reject(pn_condition) + # note: normally, link.destroy() cannot be called from a callback, + # but this link was never made available to the application so this + # link is only referenced by the connection link.destroy() @property @@ -637,12 +693,7 @@ class Connection(Endpoint): """Clean up after connection failure detected.""" if not self._error: LOG.error("Connection failed: %s", str(error)) - self._read_done = True - self._write_done = True self._error = error - # report error during the next call to process() - self._next_deadline = time.time() - return self.EOS def _configure_ssl(self, properties): if not properties: @@ -759,20 +810,16 @@ class Connection(Endpoint): """Both ends of the Endpoint have become active.""" LOG.debug("Connection is up") if self._handler: - if (_PROTON_VERSION >= (0, 10)): - # simulate the old sasl_done callback - if self._pn_sasl and not self._sasl_done: - self._sasl_done = True - self._handler.sasl_done(self, self._pn_sasl, - self._pn_sasl.outcome) - self._handler.connection_active(self) + with self._callback_lock: + self._handler.connection_active(self) def _ep_need_close(self): """The remote has closed its end of the endpoint.""" LOG.debug("Connection remotely closed") if self._handler: cond = self._pn_connection.remote_condition - self._handler.connection_remote_closed(self, cond) + with self._callback_lock: + self._handler.connection_remote_closed(self, cond) def _ep_error(self, error): """The endpoint state machine failed due to protocol error.""" diff --git a/pyngus/link.py b/pyngus/link.py index 9629af5..18f1898 100644 --- a/pyngus/link.py +++ b/pyngus/link.py @@ -42,6 +42,40 @@ _snd_settle_modes = {"settled": proton.Link.SND_SETTLED, _rcv_settle_modes = {"first": proton.Link.RCV_FIRST, "second": proton.Link.RCV_SECOND} +#TODO(kgiusti): this is duplicated in connection.py, put in common file +class _CallbackLock(object): + """A utility class for detecting when a callback invokes a non-reentrant + Pyngus method. + """ + def __init__(self, link): + super(_CallbackLock, self).__init__() + self._link = link + self.in_callback = 0 + + def __enter__(self): + # manually lock parent - can't enter its non-reentrant methods + self._link._connection._callback_lock.__enter__() + self.in_callback += 1 + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.in_callback -= 1 + self._link._connection._callback_lock.__exit__(None, None, None) + # if a call is made to a non-reentrant method while this context is + # held, then the method will raise a RuntimeError(). Return false to + # propagate the exception to the caller + return False + +def _not_reentrant(func): + """Decorator that prevents callbacks from calling into link methods that + are not reentrant """ + def wrap(*args, **kws): + link = args[0] + if link._callback_lock.in_callback: + m = "Link %s cannot be invoked from a callback!" % func + raise RuntimeError(m) + return func(*args, **kws) + return wrap class _Link(Endpoint): """A generic Link base class.""" @@ -54,6 +88,7 @@ class _Link(Endpoint): self._user_context = None self._rejected = False # requested link was refused self._failed = False # protocol error occurred + self._callback_lock = _CallbackLock(self) # TODO(kgiusti): raise jira to add 'context' attr to api self._pn_link = pn_link pn_link.context = self @@ -213,10 +248,12 @@ class _Link(Endpoint): if pn_event.type == proton.Event.DELIVERY: pn_delivery = pn_event.context pn_link = pn_delivery.link - pn_link.context._process_delivery(pn_delivery) + if pn_link.context: + pn_link.context._process_delivery(pn_delivery) elif pn_event.type == proton.Event.LINK_FLOW: pn_link = pn_event.context - pn_link.context._process_credit() + if pn_link.context: + pn_link.context._process_credit() elif ep_event is not None: pn_link = pn_event.context if pn_link.context: @@ -310,7 +347,8 @@ class SenderEventHandler(object): LOG.debug("sender_active (ignored)") def sender_remote_closed(self, sender_link, pn_condition): - LOG.debug("sender_remote_closed (ignored)") + LOG.debug("sender_remote_closed condition=%s (ignored)", + pn_condition) def sender_closed(self, sender_link): LOG.debug("sender_closed (ignored)") @@ -320,7 +358,7 @@ class SenderEventHandler(object): def sender_failed(self, sender_link, error): """Protocol error occurred.""" - LOG.debug("sender_failed (ignored)") + LOG.debug("sender_failed error=%s (ignored)", error) class SenderLink(_Link): @@ -366,7 +404,8 @@ class SenderLink(_Link): if self.tag in self.link._send_requests: del self.link._send_requests[self.tag] if self.callback: - self.callback(self.link, self.handle, state, info) + with self.link._callback_lock: + self.callback(self.link, self.handle, state, info) def __init__(self, connection, pn_link): super(SenderLink, self).__init__(connection, pn_link) @@ -419,6 +458,7 @@ class SenderLink(_Link): self._pn_link.source.type = proton.Terminus.UNSPECIFIED super(SenderLink, self).reject(pn_condition) + @_not_reentrant def destroy(self): self._connection._remove_sender(self._name) self._connection = None @@ -476,7 +516,8 @@ class SenderLink(_Link): if self._handler and not self._rejected: if self._last_credit <= 0 and new_credit > 0: LOG.debug("Credit is available, link=%s", self.name) - self._handler.credit_granted(self) + with self._callback_lock: + self._handler.credit_granted(self) self._last_credit = new_credit def _write_msg(self, pn_delivery, send_req): @@ -501,20 +542,23 @@ class SenderLink(_Link): def _link_failed(self, error): if self._handler and not self._rejected: - self._handler.sender_failed(self, error) + with self._callback_lock: + self._handler.sender_failed(self, error) # endpoint state machine actions: def _ep_active(self): LOG.debug("SenderLink is up") if self._handler and not self._rejected: - self._handler.sender_active(self) + with self._callback_lock: + self._handler.sender_active(self) def _ep_need_close(self): LOG.debug("SenderLink remote closed") if self._handler and not self._rejected: cond = self._pn_link.remote_condition - self._handler.sender_remote_closed(self, cond) + with self._callback_lock: + self._handler.sender_remote_closed(self, cond) def _ep_closed(self): LOG.debug("SenderLink close completed") @@ -526,7 +570,8 @@ class SenderLink(_Link): key, send_req = self._send_requests.popitem() send_req.destroy(SenderLink.ABORTED, info) if self._handler and not self._rejected: - self._handler.sender_closed(self) + with self._callback_lock: + self._handler.sender_closed(self) def _ep_requested(self): LOG.debug("Remote has requested a SenderLink") @@ -551,11 +596,12 @@ class SenderLink(_Link): elif (dist_mode == proton.Terminus.DIST_MODE_MOVE): props["distribution-mode"] = "move" - handler.sender_requested(self._connection, - pn_link.name, # handle - pn_link.name, - req_source, - props) + with self._connection._callback_lock: + handler.sender_requested(self._connection, + pn_link.name, # handle + pn_link.name, + req_source, + props) class ReceiverEventHandler(object): @@ -564,14 +610,15 @@ class ReceiverEventHandler(object): LOG.debug("receiver_active (ignored)") def receiver_remote_closed(self, receiver_link, pn_condition): - LOG.debug("receiver_remote_closed (ignored)") + LOG.debug("receiver_remote_closed condition=%s (ignored)", + pn_condition) def receiver_closed(self, receiver_link): LOG.debug("receiver_closed (ignored)") def receiver_failed(self, receiver_link, error): """Protocol error occurred.""" - LOG.debug("receiver_failed (ignored)") + LOG.debug("receiver_failed error=%s (ignored)", error) def message_received(self, receiver_link, message, handle): LOG.debug("message_received (ignored)") @@ -631,6 +678,7 @@ class ReceiverLink(_Link): self._pn_link.target.type = proton.Terminus.UNSPECIFIED super(ReceiverLink, self).reject(pn_condition) + @_not_reentrant def destroy(self): self._connection._remove_receiver(self._name) self._connection = None @@ -651,7 +699,8 @@ class ReceiverLink(_Link): handle = "rmsg-%s:%x" % (self._name, self._next_handle) self._next_handle += 1 self._unsettled_deliveries[handle] = pn_delivery - self._handler.message_received(self, msg, handle) + with self._callback_lock: + self._handler.message_received(self, msg, handle) else: # TODO(kgiusti): is it ok to assume Delivery.REJECTED? pn_delivery.settle() @@ -662,25 +711,29 @@ class ReceiverLink(_Link): def _link_failed(self, error): if self._handler and not self._rejected: - self._handler.receiver_failed(self, error) + with self._callback_lock: + self._handler.receiver_failed(self, error) # endpoint state machine actions: def _ep_active(self): LOG.debug("ReceiverLink is up") if self._handler and not self._rejected: - self._handler.receiver_active(self) + with self._callback_lock: + self._handler.receiver_active(self) def _ep_need_close(self): LOG.debug("ReceiverLink remote closed") if self._handler and not self._rejected: cond = self._pn_link.remote_condition - self._handler.receiver_remote_closed(self, cond) + with self._callback_lock: + self._handler.receiver_remote_closed(self, cond) def _ep_closed(self): LOG.debug("ReceiverLink close completed") if self._handler and not self._rejected: - self._handler.receiver_closed(self) + with self._callback_lock: + self._handler.receiver_closed(self) def _ep_requested(self): LOG.debug("Remote has initiated a ReceiverLink") @@ -705,11 +758,12 @@ class ReceiverLink(_Link): elif (dist_mode == proton.Terminus.DIST_MODE_MOVE): props["distribution-mode"] = "move" - handler.receiver_requested(self._connection, - pn_link.name, # handle - pn_link.name, - req_target, - props) + with self._connection._callback_lock: + handler.receiver_requested(self._connection, + pn_link.name, # handle + pn_link.name, + req_target, + props) class _SessionProxy(Endpoint): diff --git a/pyngus/sockets.py b/pyngus/sockets.py index 051566f..f176038 100644 --- a/pyngus/sockets.py +++ b/pyngus/sockets.py @@ -42,32 +42,35 @@ def read_socket_input(connection, socket_obj): if count <= 0: return count # 0 or EOS - try: - sock_data = socket_obj.recv(count) - except socket.timeout as e: - LOG.debug("Socket timeout exception %s", str(e)) - raise # caller must handle - except socket.error as e: - LOG.debug("Socket error exception %s", str(e)) - err = e.errno - # ignore non-fatal errors - if (err != errno.EAGAIN and - err != errno.EWOULDBLOCK and - err != errno.EINTR): - # otherwise, unrecoverable: - connection.close_input() - raise # caller must handle - except Exception as e: # beats me... assume fatal - LOG.debug("unknown socket exception %s", str(e)) - connection.close_input() - raise # caller must handle + while True: + try: + sock_data = socket_obj.recv(count) + break + except socket.timeout as e: + LOG.debug("Socket timeout exception %s", str(e)) + raise # caller must handle + except socket.error as e: + LOG.debug("Socket error exception %s", str(e)) + err = e.errno + if err in [errno.EAGAIN, + errno.EWOULDBLOCK, + errno.EINTR]: + # try again later + return 0 + # otherwise, unrecoverable, caller must handle + raise + except Exception as e: # beats me... assume fatal + LOG.debug("unknown socket exception %s", str(e)) + raise # caller must handle - if sock_data: + if len(sock_data) > 0: count = connection.process_input(sock_data) else: LOG.debug("Socket closed") count = Connection.EOS connection.close_input() + connection.close_output() + LOG.debug("Socket recv %s bytes", count) return count @@ -85,30 +88,34 @@ def write_socket_output(connection, socket_obj): if not data: # error - has_output > 0, but no data? return Connection.EOS - try: - count = socket_obj.send(data) - except socket.timeout as e: - LOG.debug("Socket timeout exception %s", str(e)) - raise # caller must handle - except socket.error as e: - LOG.debug("Socket error exception %s", str(e)) - err = e.errno - # ignore non-fatal errors - if (err != errno.EAGAIN and - err != errno.EWOULDBLOCK and - err != errno.EINTR): - # otherwise, unrecoverable - connection.close_output() - raise - except Exception as e: # beats me... assume fatal - LOG.debug("unknown socket exception %s", str(e)) - connection.close_output() - raise + + while True: + try: + count = socket_obj.send(data) + break + except socket.timeout as e: + LOG.debug("Socket timeout exception %s", str(e)) + raise # caller must handle + except socket.error as e: + LOG.debug("Socket error exception %s", str(e)) + err = e.errno + if err in [errno.EAGAIN, + errno.EWOULDBLOCK, + errno.EINTR]: + # try again later + return 0 + # else assume fatal let caller handle it: + raise + except Exception as e: # beats me... assume fatal + LOG.debug("unknown socket exception %s", str(e)) + raise if count > 0: + LOG.debug("Socket sent %s bytes", count) connection.output_written(count) elif data: LOG.debug("Socket closed") count = Connection.EOS connection.close_output() + connection.close_input() return count diff --git a/setup.py b/setup.py index 03f9a68..2cb2ed2 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ from setuptools import setup -_VERSION = "2.0.0a2" # NOTE: update __init__.py too! +_VERSION = "2.0.0a3" # NOTE: update __init__.py too! # I hack, therefore I am (productive) Some distros (which will not be named) # don't use setup.py to install the proton python module. In this case, pip diff --git a/tests/unit_tests/common.py b/tests/unit_tests/common.py index 99b1c7b..45d1d08 100644 --- a/tests/unit_tests/common.py +++ b/tests/unit_tests/common.py @@ -102,9 +102,14 @@ def process_connections(c1, c2, timestamp=None): c2.process(timestamp) -def _validate_callback(connection): - """Callbacks must only occur from the Connection.process() call.""" - assert connection._in_process +def _validate_conn_callback(connection): + """Callbacks must only occur when holding the Connection callback lock""" + assert connection._callback_lock.in_callback, \ + connection._callback_lock.in_callback + +def _validate_link_callback(link): + """Callbacks must only occur when holding the Link callback lock.""" + assert link._callback_lock.in_callback, link._callback_lock.in_callback class ConnCallback(pyngus.ConnectionEventHandler): @@ -131,29 +136,30 @@ class ConnCallback(pyngus.ConnectionEventHandler): self.sasl_step_ct = 0 self.sasl_done_ct = 0 + self.sasl_done_outcome = None def connection_active(self, connection): - _validate_callback(connection) + _validate_conn_callback(connection) self.active_ct += 1 def connection_failed(self, connection, error): - _validate_callback(connection) + _validate_conn_callback(connection) self.failed_ct += 1 self.failed_error = error def connection_remote_closed(self, connection, error=None): - _validate_callback(connection) + _validate_conn_callback(connection) self.remote_closed_ct += 1 self.remote_closed_error = error def connection_closed(self, connection): - _validate_callback(connection) + _validate_conn_callback(connection) self.closed_ct += 1 def sender_requested(self, connection, link_handle, name, requested_source, properties): - _validate_callback(connection) + _validate_conn_callback(connection) self.sender_requested_ct += 1 args = ConnCallback.RequestArgs(link_handle, name, requested_source, None, @@ -163,7 +169,7 @@ class ConnCallback(pyngus.ConnectionEventHandler): def receiver_requested(self, connection, link_handle, name, requested_target, properties): - _validate_callback(connection) + _validate_conn_callback(connection) self.receiver_requested_ct += 1 args = ConnCallback.RequestArgs(link_handle, name, None, requested_target, @@ -171,12 +177,13 @@ class ConnCallback(pyngus.ConnectionEventHandler): self.receiver_requested_args.append(args) def sasl_step(self, connection, pn_sasl): - _validate_callback(connection) + _validate_conn_callback(connection) self.sasl_step_ct += 1 def sasl_done(self, connection, pn_sasl, result): - _validate_callback(connection) + _validate_conn_callback(connection) self.sasl_done_ct += 1 + self.sasl_done_outcome = result class SenderCallback(pyngus.SenderEventHandler): @@ -188,20 +195,20 @@ class SenderCallback(pyngus.SenderEventHandler): self.credit_granted_ct = 0 def sender_active(self, sender_link): - _validate_callback(sender_link.connection) + _validate_link_callback(sender_link) self.active_ct += 1 def sender_remote_closed(self, sender_link, error=None): - _validate_callback(sender_link.connection) + _validate_link_callback(sender_link) self.remote_closed_ct += 1 self.remote_closed_error = error def sender_closed(self, sender_link): - _validate_callback(sender_link.connection) + _validate_link_callback(sender_link) self.closed_ct += 1 def credit_granted(self, sender_link): - _validate_callback(sender_link.connection) + _validate_link_callback(sender_link) self.credit_granted_ct += 1 @@ -215,7 +222,7 @@ class DeliveryCallback(object): self.count = 0 def __call__(self, link, handle, status, info): - _validate_callback(link.connection) + _validate_link_callback(link) self.link = link self.handle = handle self.status = status @@ -233,19 +240,23 @@ class ReceiverCallback(pyngus.ReceiverEventHandler): self.received_messages = [] def receiver_active(self, receiver_link): - _validate_callback(receiver_link.connection) + _validate_link_callback(receiver_link) + _validate_conn_callback(receiver_link.connection) self.active_ct += 1 def receiver_remote_closed(self, receiver_link, error=None): - _validate_callback(receiver_link.connection) + _validate_link_callback(receiver_link) + _validate_conn_callback(receiver_link.connection) self.remote_closed_ct += 1 self.remote_closed_error = error def receiver_closed(self, receiver_link): - _validate_callback(receiver_link.connection) + _validate_link_callback(receiver_link) + _validate_conn_callback(receiver_link.connection) self.closed_ct += 1 def message_received(self, receiver_link, message, handle): - _validate_callback(receiver_link.connection) + _validate_link_callback(receiver_link) + _validate_conn_callback(receiver_link.connection) self.message_received_ct += 1 self.received_messages.append((message, handle)) diff --git a/tests/unit_tests/connection.py b/tests/unit_tests/connection.py index 49ed1d1..8e43014 100644 --- a/tests/unit_tests/connection.py +++ b/tests/unit_tests/connection.py @@ -435,8 +435,8 @@ class APITest(common.Test): assert cb1.failed_ct > 0 assert cb1.failed_error - def test_process_reentrancy(self): - """Catch any attempt to re-enter Connection.process() from a + def test_non_reentrant_callback(self): + """Catch any attempt to call a non-reentrant Connection method from a callback.""" class BadCallback(common.ConnCallback): def connection_active(self, connection): @@ -896,5 +896,7 @@ mech_list: EXTERNAL DIGEST-MD5 SCRAM-SHA-1 CRAM-MD5 PLAIN ANONYMOUS assert not c1.active, c1.active assert c1_events.failed_ct == 1, c1_events.failed_ct assert not c2.active, c2.active - assert c2_events.sasl_done_ct == 0, c2_events.sasl_done_ct + assert c2_events.sasl_done_ct == 1, c2_events.sasl_done_ct + # outcome 1 == auth error + assert c2_events.sasl_done_outcome == 1, c2_events.sasl_done_outcome assert c2_events.failed_ct == 1, c2_events.failed_ct diff --git a/tests/unit_tests/link.py b/tests/unit_tests/link.py index 90d8a4b..4cd9021 100644 --- a/tests/unit_tests/link.py +++ b/tests/unit_tests/link.py @@ -704,3 +704,63 @@ class APITest(common.Test): assert mode and mode == 'unsettled' mode = args.properties.get('rcv-settle-mode') assert mode and mode == 'second' + + def test_non_reentrant_callback(self): + class SenderBadCallback(common.SenderCallback): + def sender_active(self, sender): + # Illegal: + sender.destroy() + + class ReceiverBadCallback(common.ReceiverCallback): + def receiver_active(self, receiver): + # Illegal: + receiver.destroy() + + class ReceiverBadCallback2(common.ReceiverCallback): + def receiver_active(self, receiver): + # Illegal: + receiver.connection.destroy() + + sc = SenderBadCallback() + sender = self.conn1.create_sender("src1", "tgt1", + event_handler=sc) + sender.open() + self.process_connections() + assert self.conn2_handler.receiver_requested_ct + args = self.conn2_handler.receiver_requested_args[-1] + receiver = self.conn2.accept_receiver(args.link_handle) + receiver.open() + try: + self.process_connections() + assert False, "RuntimeError expected!" + except RuntimeError: + pass + + sender = self.conn1.create_sender("src2", "tgt2") + sender.open() + self.process_connections() + args = self.conn2_handler.receiver_requested_args[-1] + rc = ReceiverBadCallback() + receiver = self.conn2.accept_receiver(args.link_handle, + event_handler=rc) + receiver.open() + try: + self.process_connections() + assert False, "RuntimeError expected!" + except RuntimeError: + pass + + sender = self.conn1.create_sender("src3", "tgt3") + sender.open() + self.process_connections() + args = self.conn2_handler.receiver_requested_args[-1] + rc = ReceiverBadCallback2() + receiver = self.conn2.accept_receiver(args.link_handle, + event_handler=rc) + receiver.open() + try: + self.process_connections() + assert False, "RuntimeError expected!" + except RuntimeError: + pass +