Fix reentrancy issues

This commit is contained in:
Kenneth Giusti
2015-08-24 09:58:17 -04:00
parent 2492a6a00d
commit c5bfa24e29
13 changed files with 562 additions and 244 deletions

View File

@@ -33,6 +33,42 @@ LOG = logging.getLogger()
LOG.addHandler(logging.StreamHandler())
class ConnectionEventHandler(pyngus.ConnectionEventHandler):
def connection_failed(self, connection, error):
"""Connection has failed in some way."""
LOG.warn("Connection failed: %s", error)
connection.close()
def connection_remote_closed(self, connection, pn_condition):
"""Peer has closed its end of the connection."""
LOG.debug("connection_remote_closed condition=%s", pn_condition)
connection.close()
class ReceiverEventHandler(pyngus.ReceiverEventHandler):
def __init__(self):
self.done = False
self.message = None
self.handle = None
def receiver_remote_closed(self, receiver_link, pn_condition):
"""Peer has closed its end of the link."""
LOG.debug("receiver_remote_closed condition=%s", pn_condition)
receiver_link.close()
self.done = True
def receiver_failed(self, receiver_link, error):
"""Protocol error occurred."""
LOG.warn("receiver_failed error=%s", error)
receiver_link.close()
self.done = True
def message_received(self, receiver, message, handle):
self.done = True
self.message = message
self.handle = handle
def main(argv=None):
_usage = """Usage: %prog [options]"""
@@ -53,6 +89,12 @@ def main(argv=None):
help="enable protocol tracing")
parser.add_option("--ca",
help="Certificate Authority PEM file")
parser.add_option("--username", type="string",
help="User Id for authentication")
parser.add_option("--password", type="string",
help="User password for authentication")
parser.add_option("--sasl-mechs", type="string",
help="The list of acceptable SASL mechs")
opts, extra = parser.parse_args(args=argv)
if opts.debug:
@@ -64,35 +106,28 @@ def main(argv=None):
#
container = pyngus.Container(uuid.uuid4().hex)
conn_properties = {'hostname': host,
'x-server': False,
'x-username': 'guest',
'x-password': 'guest',
'x-sasl-mechs': "ANONYMOUS PLAIN"}
'x-server': False}
if opts.trace:
conn_properties["x-trace-protocol"] = True
if opts.ca:
conn_properties["x-ssl-ca-file"] = opts.ca
if opts.idle_timeout:
conn_properties["idle-time-out"] = opts.idle_timeout
if opts.username:
conn_properties['x-username'] = opts.username
if opts.password:
conn_properties['x-password'] = opts.password
if opts.sasl_mechs:
conn_properties['x-sasl-mechs'] = opts.sasl_mechs
c_handler = pyngus.ConnectionEventHandler()
connection = container.create_connection("receiver",
None, # no events
c_handler,
conn_properties)
connection.open()
class ReceiveCallback(pyngus.ReceiverEventHandler):
def __init__(self):
self.done = False
self.message = None
self.handle = None
def message_received(self, receiver, message, handle):
self.done = True
self.message = message
self.handle = handle
target_address = opts.target_addr or uuid.uuid4().hex
cb = ReceiveCallback()
cb = ReceiverEventHandler()
receiver = connection.create_receiver(target_address,
opts.source_addr,
cb)
@@ -104,8 +139,10 @@ def main(argv=None):
process_connection(connection, my_socket)
if cb.done:
print("Receive done, message=%s" % str(cb.message))
receiver.message_accepted(cb.handle)
print("Receive done, message=%s" % str(cb.message) if cb.message
else "ERROR: no message received")
if cb.handle:
receiver.message_accepted(cb.handle)
else:
print("Receive failed due to connection failure!")

View File

@@ -93,12 +93,23 @@ class MyConnection(pyngus.ConnectionEventHandler):
LOG.debug("select() returned")
if readable:
pyngus.read_socket_input(self.connection,
self.socket)
try:
pyngus.read_socket_input(self.connection,
self.socket)
except Exception as e:
LOG.error("Exception on socket read: %s", str(e))
self.connection.close_input()
self.connection.close()
self.connection.process(time.time())
if writable:
pyngus.write_socket_output(self.connection,
self.socket)
try:
pyngus.write_socket_output(self.connection,
self.socket)
except Exception as e:
LOG.error("Exception on socket write: %s", str(e))
self.connection.close_output()
self.connection.close()
def close(self, error=None):
self.connection.close(error)
@@ -131,6 +142,7 @@ class MyConnection(pyngus.ConnectionEventHandler):
def connection_remote_closed(self, connection, reason):
LOG.debug("Connection remote closed callback")
connection.close()
def connection_closed(self, connection):
LOG.debug("Connection closed callback")

View File

@@ -53,6 +53,9 @@ LOG.addHandler(logging.StreamHandler())
sender_links = {}
receiver_links = {}
# links that have closed and need to be destroyed:
dead_links = set()
# Map reply-to address to the proper sending link (indexed by address)
reply_senders = {}
@@ -92,10 +95,8 @@ class SocketConnection(pyngus.ConnectionEventHandler):
pyngus.read_socket_input(self.connection, self.socket)
except Exception as e:
LOG.error("Exception on socket read: %s", str(e))
# may be redundant if closed cleanly:
self.connection_closed(self.connection)
return
self.connection.close_input()
self.connection.close()
self.connection.process(time.time())
def send_output(self):
@@ -105,10 +106,8 @@ class SocketConnection(pyngus.ConnectionEventHandler):
self.socket)
except Exception as e:
LOG.error("Exception on socket write: %s", str(e))
# may be redundant if closed cleanly:
self.connection_closed(self.connection)
return
self.connection.close_output()
self.connection.close()
self.connection.process(time.time())
# ConnectionEventHandler callbacks:
@@ -198,6 +197,12 @@ class MySenderLink(pyngus.SenderEventHandler):
properties)
self.sender_link.open()
@property
def closed(self):
if self.sender_link:
return self.sender_link.closed
return True
# SenderEventHandler callbacks:
def sender_active(self, sender_link):
@@ -211,12 +216,14 @@ class MySenderLink(pyngus.SenderEventHandler):
LOG.debug("sender closed callback")
global sender_links
global reply_senders
global dead_links
if self._ident in sender_links:
del sender_links[self._ident]
if self._source_address in reply_senders:
del reply_senders[self._source_address]
self.sender_link.destroy()
dead_links.add(self.sender_link)
self.sender_link = None
# 'message sent' callback:
@@ -238,6 +245,12 @@ class MyReceiverLink(pyngus.ReceiverEventHandler):
properties)
self._link.open()
@property
def closed(self):
if self._link:
return self._link.closed
return True
# ReceiverEventHandler callbacks:
def receiver_active(self, receiver_link):
LOG.debug("receiver active callback")
@@ -250,10 +263,12 @@ class MyReceiverLink(pyngus.ReceiverEventHandler):
def receiver_closed(self, receiver_link):
LOG.debug("receiver closed callback")
global receiver_links
global dead_links
if self._ident in receiver_links:
del receiver_links[self._ident]
self._link.destroy()
dead_links.add(self._link)
self._link = None
def message_received(self, receiver_link, message, handle):
@@ -353,6 +368,7 @@ def main(argv=None):
#
container = pyngus.Container("example RPC service")
global socket_connections
global dead_links
while True:
@@ -428,7 +444,11 @@ def main(argv=None):
w.send_output()
worked.append(w)
# nuke any completed connections:
# first, free any closed links:
while dead_links:
dead_links.pop().destroy()
# then nuke any completed connections:
closed = False
while worked:
sc = worked.pop()

View File

@@ -35,6 +35,29 @@ LOG = logging.getLogger()
LOG.addHandler(logging.StreamHandler())
class ConnectionEventHandler(pyngus.ConnectionEventHandler):
def connection_failed(self, connection, error):
"""Connection's transport has failed in some way."""
LOG.warn("Connection failed: %s", error)
connection.close()
def connection_remote_closed(self, connection, pn_condition):
"""Peer has closed its end of the connection."""
LOG.debug("connection_remote_closed condition=%s", pn_condition)
connection.close()
class SenderEventHandler(pyngus.SenderEventHandler):
def sender_remote_closed(self, sender_link, pn_condition):
LOG.debug("Sender peer_closed condition=%s", pn_condition)
sender_link.close()
def sender_failed(self, sender_link, error):
"""Protocol error occurred."""
LOG.debug("Sender failed error=%s", error)
sender_link.close()
def main(argv=None):
_usage = """Usage: %prog [options] [message content string]"""
@@ -55,6 +78,12 @@ def main(argv=None):
help="enable protocol tracing")
parser.add_option("--ca",
help="Certificate Authority PEM file")
parser.add_option("--username", type="string",
help="User Id for authentication")
parser.add_option("--password", type="string",
help="User password for authentication")
parser.add_option("--sasl-mechs", type="string",
help="The list of acceptable SASL mechs")
opts, payload = parser.parse_args(args=argv)
if not payload:
@@ -69,23 +98,31 @@ def main(argv=None):
#
container = pyngus.Container(uuid.uuid4().hex)
conn_properties = {'hostname': host,
'x-server': False,
'x-sasl-mechs': "ANONYMOUS PLAIN"}
'x-server': False}
if opts.trace:
conn_properties["x-trace-protocol"] = True
if opts.ca:
conn_properties["x-ssl-ca-file"] = opts.ca
if opts.idle_timeout:
conn_properties["idle-time-out"] = opts.idle_timeout
if opts.username:
conn_properties['x-username'] = opts.username
if opts.password:
conn_properties['x-password'] = opts.password
if opts.sasl_mechs:
conn_properties['x-sasl-mechs'] = opts.sasl_mechs
c_handler = ConnectionEventHandler()
connection = container.create_connection("sender",
None, # no events
c_handler,
conn_properties)
connection.open()
source_address = opts.source_addr or uuid.uuid4().hex
s_handler = SenderEventHandler()
sender = connection.create_sender(source_address,
opts.target_addr)
opts.target_addr,
s_handler)
sender.open()
# Send a single message:

View File

@@ -47,13 +47,11 @@ class SocketConnection(pyngus.ConnectionEventHandler):
properties)
self.connection.user_context = self
self.connection.open()
self.closed = False
self.sender_links = set()
self.receiver_links = set()
def destroy(self):
self.closed = True
for link in self.sender_links.copy():
link.destroy()
for link in self.receiver_links.copy():
@@ -65,6 +63,10 @@ class SocketConnection(pyngus.ConnectionEventHandler):
self.socket.close()
self.socket = None
@property
def closed(self):
return self.connection == None or self.connection.closed
def fileno(self):
"""Allows use of a SocketConnection in a select() call."""
return self.socket.fileno()
@@ -75,9 +77,8 @@ class SocketConnection(pyngus.ConnectionEventHandler):
pyngus.read_socket_input(self.connection, self.socket)
except Exception as e:
LOG.error("Exception on socket read: %s", str(e))
# may be redundant if closed cleanly:
self.connection_closed(self.connection)
return
self.connection.close_input()
self.connection.close()
self.connection.process(time.time())
def send_output(self):
@@ -87,9 +88,8 @@ class SocketConnection(pyngus.ConnectionEventHandler):
self.socket)
except Exception as e:
LOG.error("Exception on socket write: %s", str(e))
# may be redundant if closed cleanly:
self.connection_closed(self.connection)
return
self.connection.close_output()
self.connection.close()
self.connection.process(time.time())
# ConnectionEventHandler callbacks:
@@ -102,13 +102,10 @@ class SocketConnection(pyngus.ConnectionEventHandler):
def connection_closed(self, connection):
LOG.debug("Connection: closed.")
# main loop will destroy
self.closed = True
def connection_failed(self, connection, error):
LOG.error("Connection: failed! error=%s", str(error))
# No special recovery - just simulate close completed:
self.connection_closed(connection)
self.connection.close()
def sender_requested(self, connection, link_handle,
name, requested_source, properties):
@@ -152,6 +149,10 @@ class MySenderLink(pyngus.SenderEventHandler):
self.sender_link.open()
print("New Sender link created, name=%s" % sl.name)
@property
def closed(self):
return self.sender_link.closed
def destroy(self):
print("Sender link destroyed, name=%s" % self.sender_link.name)
self.socket_conn.sender_links.discard(self)
@@ -176,11 +177,6 @@ class MySenderLink(pyngus.SenderEventHandler):
LOG.debug("Sender: Remote closed")
self.sender_link.close()
def sender_closed(self, sender_link):
LOG.debug("Sender: Closed")
# Done with this sender:
self.destroy()
def credit_granted(self, sender_link):
LOG.debug("Sender: credit granted")
# Send a single message:
@@ -208,6 +204,10 @@ class MyReceiverLink(pyngus.ReceiverEventHandler):
self.receiver_link.add_capacity(1)
print("New Receiver link created, name=%s" % rl.name)
@property
def closed(self):
return self.receiver_link.closed
def destroy(self):
print("Receiver link destroyed, name=%s" % self.receiver_link.name)
self.socket_conn.receiver_links.discard(self)
@@ -224,11 +224,6 @@ class MyReceiverLink(pyngus.ReceiverEventHandler):
LOG.debug("Receiver: Remote closed")
self.receiver_link.close()
def receiver_closed(self, receiver_link):
LOG.debug("Receiver: Closed")
# Done with this Receiver:
self.destroy()
def message_received(self, receiver_link, message, handle):
self.receiver_link.message_accepted(handle)
print("Message received on Receiver link %s, message=%s"
@@ -256,6 +251,14 @@ def main(argv=None):
help="PEM File containing the server's private key")
parser.add_option("--keypass",
help="Password used to decrypt key file")
parser.add_option("--require-auth", action="store_true",
help="Require clients to authenticate")
parser.add_option("--sasl-mechs", type="string",
help="The list of acceptable SASL mechs")
parser.add_option("--sasl-cfg-name", type="string",
help="name of SASL config file (no suffix)")
parser.add_option("--sasl-cfg-dir", type="string",
help="Path to the SASL config file")
opts, arguments = parser.parse_args(args=argv)
if opts.debug:
@@ -300,9 +303,15 @@ def main(argv=None):
client_socket, client_address = my_socket.accept()
# name = uuid.uuid4().hex
name = str(client_address)
conn_properties = {'x-server': True,
'x-require-auth': False,
'x-sasl-mechs': "ANONYMOUS"}
conn_properties = {'x-server': True}
if opts.require_auth:
conn_properties['x-require-auth'] = True
if opts.sasl_mechs:
conn_properties['x-sasl-mechs'] = opts.sasl_mechs
if opts.sasl_cfg_name:
conn_properties['x-sasl-config-name'] = opts.sasl_cfg_name
if opts.sasl_cfg_dir:
conn_properties['x-sasl-config-dir'] = opts.sasl_cfg_dir
if opts.idle_timeout:
conn_properties["idle-time-out"] = opts.idle_timeout
if opts.trace:
@@ -337,14 +346,20 @@ def main(argv=None):
w.send_output()
worked.add(w)
# nuke any completed connections:
closed = False
while worked:
sc = worked.pop()
# nuke any completed connections:
if sc.closed:
socket_connections.discard(sc)
sc.destroy()
closed = True
else:
# can free any closed links now (optional):
for link in sc.sender_links | sc.receiver_links:
if link.closed:
link.destroy()
if closed:
LOG.debug("%d active connections present", len(socket_connections))

View File

@@ -19,6 +19,7 @@
"""Utilities used by the Examples"""
import errno
import logging
import re
import socket
import select
@@ -26,6 +27,7 @@ import time
import pyngus
LOG = logging.getLogger()
def get_host_port(server_address):
"""Parse the hostname and port out of the server_address."""
@@ -103,10 +105,24 @@ def process_connection(connection, my_socket):
[],
timeout)
if readable:
pyngus.read_socket_input(connection, my_socket)
try:
pyngus.read_socket_input(connection, my_socket)
except Exception as e:
# treat any socket error as
LOG.error("Socket error on read: %s", str(e))
connection.close_input()
# make an attempt to cleanly close
connection.close()
connection.process(time.time())
if writable:
pyngus.write_socket_output(connection, my_socket)
try:
pyngus.write_socket_output(connection, my_socket)
except Exception as e:
LOG.error("Socket error on write %s", str(e))
connection.close_output()
# this may not help, but it won't hurt:
connection.close()
return True
# Map the send callback status to a string

View File

@@ -35,6 +35,24 @@ LOG = logging.getLogger(__name__)
_PROTON_VERSION = (int(getattr(proton, "VERSION_MAJOR", 0)),
int(getattr(proton, "VERSION_MINOR", 0)))
class _CallbackLock(object):
"""A utility class for detecting when a callback invokes a non-reentrant
Pyngus method.
"""
def __init__(self):
super(_CallbackLock, self).__init__()
self.in_callback = 0
def __enter__(self):
self.in_callback += 1
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.in_callback -= 1
# if a call is made to a non-reentrant method while this context is
# held, then the method will raise a RuntimeError(). Return false to
# propagate the exception to the caller
return False
class ConnectionEventHandler(object):
"""An implementation of an AMQP 1.0 Connection."""
@@ -89,6 +107,17 @@ class Connection(Endpoint):
'x-sasl-mechs', 'x-sasl-config-dir',
'x-sasl-config-name'])
def _not_reentrant(func):
"""Decorator that prevents callbacks from calling into methods that are
not reentrant
"""
def wrap(self, *args, **kws):
if self._callback_lock.in_callback:
m = "Connection %s cannot be invoked from a callback!" % func
raise RuntimeError(m)
return func(self, *args, **kws)
return wrap
def __init__(self, container, name, event_handler=None, properties=None):
"""Create a new connection from the Container
@@ -213,8 +242,8 @@ class Connection(Endpoint):
self._error = None
self._next_deadline = 0
self._user_context = None
self._in_process = False
self._remote_session_id = 0
self._callback_lock = _CallbackLock()
self._pn_sasl = None
self._sasl_done = False
@@ -242,19 +271,21 @@ class Connection(Endpoint):
else:
# new Proton SASL configuration
if 'x-require-auth' in self._properties:
self._init_sasl()
ra = self._properties['x-require-auth']
self._pn_transport.require_auth(ra)
if 'x-username' in self._properties:
self._init_sasl()
# maintain old behavior: allow PLAIN and ANONYMOUS
# authentication. Override this using x-sasl-mechs below:
self.pn_sasl.allow_insecure_mechs = True
self._pn_connection.user = self._properties['x-username']
if 'x-password' in self._properties:
self._init_sasl()
self._pn_connection.password = \
self._properties['x-password']
if 'x-sasl-mechs' in self._properties:
self.pn_sasl.allowed_mechs(
self._properties['x-sasl-mechs'])
mechs = self._properties['x-sasl-mechs'].upper()
self.pn_sasl.allowed_mechs(mechs)
if 'PLAIN' not in mechs and 'ANONYMOUS' not in mechs:
self.pn_sasl.allow_insecure_mechs = False
if 'x-sasl-config-dir' in self._properties:
self.pn_sasl.config_path(
self._properties['x-sasl-config-dir'])
@@ -309,13 +340,10 @@ class Connection(Endpoint):
return self._pn_connection.remote_properties
return None
def _init_sasl(self):
if not self._pn_sasl:
self._pn_sasl = self._pn_transport.sasl()
@property
def pn_sasl(self):
self._init_sasl()
if not self._pn_sasl:
self._pn_sasl = self._pn_transport.sasl()
return self._pn_sasl
def pn_ssl(self):
@@ -359,8 +387,13 @@ class Connection(Endpoint):
"""Return True if the Connection has finished closing."""
return (self._write_done and self._read_done)
@_not_reentrant
def destroy(self):
self._error = "Destroyed"
# if a connection is destroyed without flushing pending output,
# the remote will see an unclean shutdown (framing error)
if self.has_output > 0:
LOG.debug("Connection with buffered output destroyed")
self._error = "Destroyed by the application"
self._handler = None
self._sender_links.clear()
self._receiver_links.clear()
@@ -382,79 +415,80 @@ class Connection(Endpoint):
_CLOSED = (proton.Endpoint.LOCAL_CLOSED | proton.Endpoint.REMOTE_CLOSED)
_ACTIVE = (proton.Endpoint.LOCAL_ACTIVE | proton.Endpoint.REMOTE_ACTIVE)
@_not_reentrant
def process(self, now):
"""Perform connection state processing."""
if self._in_process:
raise RuntimeError("Connection.process() is not re-entrant!")
self._in_process = True
try:
# if the connection has hit an unrecoverable error,
# nag the application until connection is destroyed
if self._error:
if self._handler:
self._handler.connection_failed(self, self._error)
# nag application until connection is destroyed
self._next_deadline = now
return now
if self._pn_connection is None:
LOG.error("Connection.process() called on destroyed connection!")
return 0
# do nothing until the connection has been opened
if self._pn_connection.state & proton.Endpoint.LOCAL_UNINIT:
return 0
# do nothing until the connection has been opened
if self._pn_connection.state & proton.Endpoint.LOCAL_UNINIT:
return 0
if self._pn_sasl and not self._sasl_done:
# wait until SASL has authenticated
if (_PROTON_VERSION < (0, 10)):
# wait until SASL has authenticated
if self._pn_sasl and not self._sasl_done:
if self._pn_sasl.state not in (proton.SASL.STATE_PASS,
proton.SASL.STATE_FAIL):
LOG.debug("SASL in progress. State=%s",
str(self._pn_sasl.state))
if self._handler:
self._handler.sasl_step(self, self._pn_sasl)
return self._next_deadline
self._sasl_done = True
if self._pn_sasl.state not in (proton.SASL.STATE_PASS,
proton.SASL.STATE_FAIL):
LOG.debug("SASL in progress. State=%s",
str(self._pn_sasl.state))
if self._handler:
with self._callback_lock:
self._handler.sasl_step(self, self._pn_sasl)
return self._next_deadline
self._sasl_done = True
if self._handler:
with self._callback_lock:
self._handler.sasl_done(self, self._pn_sasl,
self._pn_sasl.outcome)
# process timer events:
timer_deadline = self._expire_timers(now)
transport_deadline = self._pn_transport.tick(now)
if timer_deadline and transport_deadline:
self._next_deadline = min(timer_deadline, transport_deadline)
else:
self._next_deadline = timer_deadline or transport_deadline
if self._pn_sasl.outcome is not None:
self._sasl_done = True
if self._handler:
with self._callback_lock:
self._handler.sasl_done(self, self._pn_sasl,
self._pn_sasl.outcome)
# process events from proton:
# process timer events:
timer_deadline = self._expire_timers(now)
transport_deadline = self._pn_transport.tick(now)
if timer_deadline and transport_deadline:
self._next_deadline = min(timer_deadline, transport_deadline)
else:
self._next_deadline = timer_deadline or transport_deadline
# process events from proton:
pn_event = self._pn_collector.peek()
while pn_event:
LOG.debug("pn_event: %s received", pn_event.type)
if self._handle_proton_event(pn_event):
pass
elif _SessionProxy._handle_proton_event(pn_event, self):
pass
elif _Link._handle_proton_event(pn_event, self):
pass
self._pn_collector.pop()
pn_event = self._pn_collector.peek()
while pn_event:
if self._handle_proton_event(pn_event):
pass
elif _SessionProxy._handle_proton_event(pn_event, self):
pass
elif _Link._handle_proton_event(pn_event, self):
pass
self._pn_collector.pop()
pn_event = self._pn_collector.peek()
# re-check for connection failure after processing all pending
# engine events:
if self._error:
if self._handler:
# check for connection failure after processing all pending
# engine events:
if self._error:
if self._handler:
# nag application until connection is destroyed
self._next_deadline = now
with self._callback_lock:
self._handler.connection_failed(self, self._error)
# nag application until connection is destroyed
self._next_deadline = now
elif (self._endpoint_state == self._CLOSED and
self._read_done and self._write_done):
# invoke closed callback after endpoint has fully closed and
# all pending I/O has completed:
if self._handler:
elif (self._endpoint_state == self._CLOSED and
self._read_done and self._write_done):
# invoke closed callback after endpoint has fully closed and
# all pending I/O has completed:
if self._handler:
with self._callback_lock:
self._handler.connection_closed(self)
return self._next_deadline
finally:
self._in_process = False
return self._next_deadline
@property
def next_tick(self):
@@ -470,14 +504,17 @@ class Connection(Endpoint):
@property
def needs_input(self):
if self._read_done:
LOG.debug("needs_input EOS")
return self.EOS
try:
# TODO(grs): can this actually throw?
capacity = self._pn_transport.capacity()
except Exception as e:
return self._connection_failed(str(e))
self._read_done = True
self._connection_failed(str(e))
return self.EOS
if capacity >= 0:
return capacity
LOG.debug("needs_input read done")
self._read_done = True
return self.EOS
@@ -486,10 +523,14 @@ class Connection(Endpoint):
if c <= 0:
return c
try:
LOG.debug("pushing %s bytes to transport:", c)
rc = self._pn_transport.push(in_data[:c])
except Exception as e:
return self._connection_failed(str(e))
self._read_done = True
self._connection_failed(str(e))
return self.EOS
if rc: # error?
LOG.debug("process_input read done")
self._read_done = True
return self.EOS
# hack: check if this was the last input needed by the connection.
@@ -504,18 +545,23 @@ class Connection(Endpoint):
self._pn_transport.close_tail()
except Exception as e:
self._connection_failed(str(e))
LOG.debug("close_input read done")
self._read_done = True
@property
def has_output(self):
if self._write_done:
LOG.debug("has output EOS")
return self.EOS
try:
pending = self._pn_transport.pending()
except Exception as e:
return self._connection_failed(str(e))
self._write_done = True
self._connection_failed(str(e))
return self.EOS
if pending >= 0:
return pending
LOG.debug("has output write_done")
self._write_done = True
return self.EOS
@@ -526,6 +572,7 @@ class Connection(Endpoint):
if c <= 0:
return None
try:
LOG.debug("Getting %s bytes output from transport", c)
buf = self._pn_transport.peek(c)
except Exception as e:
self._connection_failed(str(e))
@@ -534,9 +581,11 @@ class Connection(Endpoint):
def output_written(self, count):
try:
LOG.debug("Popping %s bytes output from transport", count)
self._pn_transport.pop(count)
except Exception as e:
return self._connection_failed(str(e))
self._write_done = True
self._connection_failed(str(e))
# hack: check if this was the last output from the connection. If so,
# this will set the _write_done flag and the 'connection closed'
# callback can be issued on the next call to process()
@@ -548,6 +597,7 @@ class Connection(Endpoint):
self._pn_transport.close_head()
except Exception as e:
self._connection_failed(str(e))
LOG.debug("close output write done")
self._write_done = True
def create_sender(self, source_address, target_address=None,
@@ -584,6 +634,9 @@ class Connection(Endpoint):
if not link:
raise Exception("Invalid link_handle: %s" % link_handle)
link.reject(pn_condition)
# note: normally, link.destroy() cannot be called from a callback,
# but this link was never made available to the application so this
# link is only referenced by the connection
link.destroy()
def create_receiver(self, target_address, source_address=None,
@@ -619,6 +672,9 @@ class Connection(Endpoint):
if not link:
raise Exception("Invalid link_handle: %s" % link_handle)
link.reject(pn_condition)
# note: normally, link.destroy() cannot be called from a callback,
# but this link was never made available to the application so this
# link is only referenced by the connection
link.destroy()
@property
@@ -637,12 +693,7 @@ class Connection(Endpoint):
"""Clean up after connection failure detected."""
if not self._error:
LOG.error("Connection failed: %s", str(error))
self._read_done = True
self._write_done = True
self._error = error
# report error during the next call to process()
self._next_deadline = time.time()
return self.EOS
def _configure_ssl(self, properties):
if not properties:
@@ -759,20 +810,16 @@ class Connection(Endpoint):
"""Both ends of the Endpoint have become active."""
LOG.debug("Connection is up")
if self._handler:
if (_PROTON_VERSION >= (0, 10)):
# simulate the old sasl_done callback
if self._pn_sasl and not self._sasl_done:
self._sasl_done = True
self._handler.sasl_done(self, self._pn_sasl,
self._pn_sasl.outcome)
self._handler.connection_active(self)
with self._callback_lock:
self._handler.connection_active(self)
def _ep_need_close(self):
"""The remote has closed its end of the endpoint."""
LOG.debug("Connection remotely closed")
if self._handler:
cond = self._pn_connection.remote_condition
self._handler.connection_remote_closed(self, cond)
with self._callback_lock:
self._handler.connection_remote_closed(self, cond)
def _ep_error(self, error):
"""The endpoint state machine failed due to protocol error."""

View File

@@ -42,6 +42,40 @@ _snd_settle_modes = {"settled": proton.Link.SND_SETTLED,
_rcv_settle_modes = {"first": proton.Link.RCV_FIRST,
"second": proton.Link.RCV_SECOND}
#TODO(kgiusti): this is duplicated in connection.py, put in common file
class _CallbackLock(object):
"""A utility class for detecting when a callback invokes a non-reentrant
Pyngus method.
"""
def __init__(self, link):
super(_CallbackLock, self).__init__()
self._link = link
self.in_callback = 0
def __enter__(self):
# manually lock parent - can't enter its non-reentrant methods
self._link._connection._callback_lock.__enter__()
self.in_callback += 1
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.in_callback -= 1
self._link._connection._callback_lock.__exit__(None, None, None)
# if a call is made to a non-reentrant method while this context is
# held, then the method will raise a RuntimeError(). Return false to
# propagate the exception to the caller
return False
def _not_reentrant(func):
"""Decorator that prevents callbacks from calling into link methods that
are not reentrant """
def wrap(*args, **kws):
link = args[0]
if link._callback_lock.in_callback:
m = "Link %s cannot be invoked from a callback!" % func
raise RuntimeError(m)
return func(*args, **kws)
return wrap
class _Link(Endpoint):
"""A generic Link base class."""
@@ -54,6 +88,7 @@ class _Link(Endpoint):
self._user_context = None
self._rejected = False # requested link was refused
self._failed = False # protocol error occurred
self._callback_lock = _CallbackLock(self)
# TODO(kgiusti): raise jira to add 'context' attr to api
self._pn_link = pn_link
pn_link.context = self
@@ -213,10 +248,12 @@ class _Link(Endpoint):
if pn_event.type == proton.Event.DELIVERY:
pn_delivery = pn_event.context
pn_link = pn_delivery.link
pn_link.context._process_delivery(pn_delivery)
if pn_link.context:
pn_link.context._process_delivery(pn_delivery)
elif pn_event.type == proton.Event.LINK_FLOW:
pn_link = pn_event.context
pn_link.context._process_credit()
if pn_link.context:
pn_link.context._process_credit()
elif ep_event is not None:
pn_link = pn_event.context
if pn_link.context:
@@ -310,7 +347,8 @@ class SenderEventHandler(object):
LOG.debug("sender_active (ignored)")
def sender_remote_closed(self, sender_link, pn_condition):
LOG.debug("sender_remote_closed (ignored)")
LOG.debug("sender_remote_closed condition=%s (ignored)",
pn_condition)
def sender_closed(self, sender_link):
LOG.debug("sender_closed (ignored)")
@@ -320,7 +358,7 @@ class SenderEventHandler(object):
def sender_failed(self, sender_link, error):
"""Protocol error occurred."""
LOG.debug("sender_failed (ignored)")
LOG.debug("sender_failed error=%s (ignored)", error)
class SenderLink(_Link):
@@ -366,7 +404,8 @@ class SenderLink(_Link):
if self.tag in self.link._send_requests:
del self.link._send_requests[self.tag]
if self.callback:
self.callback(self.link, self.handle, state, info)
with self.link._callback_lock:
self.callback(self.link, self.handle, state, info)
def __init__(self, connection, pn_link):
super(SenderLink, self).__init__(connection, pn_link)
@@ -419,6 +458,7 @@ class SenderLink(_Link):
self._pn_link.source.type = proton.Terminus.UNSPECIFIED
super(SenderLink, self).reject(pn_condition)
@_not_reentrant
def destroy(self):
self._connection._remove_sender(self._name)
self._connection = None
@@ -476,7 +516,8 @@ class SenderLink(_Link):
if self._handler and not self._rejected:
if self._last_credit <= 0 and new_credit > 0:
LOG.debug("Credit is available, link=%s", self.name)
self._handler.credit_granted(self)
with self._callback_lock:
self._handler.credit_granted(self)
self._last_credit = new_credit
def _write_msg(self, pn_delivery, send_req):
@@ -501,20 +542,23 @@ class SenderLink(_Link):
def _link_failed(self, error):
if self._handler and not self._rejected:
self._handler.sender_failed(self, error)
with self._callback_lock:
self._handler.sender_failed(self, error)
# endpoint state machine actions:
def _ep_active(self):
LOG.debug("SenderLink is up")
if self._handler and not self._rejected:
self._handler.sender_active(self)
with self._callback_lock:
self._handler.sender_active(self)
def _ep_need_close(self):
LOG.debug("SenderLink remote closed")
if self._handler and not self._rejected:
cond = self._pn_link.remote_condition
self._handler.sender_remote_closed(self, cond)
with self._callback_lock:
self._handler.sender_remote_closed(self, cond)
def _ep_closed(self):
LOG.debug("SenderLink close completed")
@@ -526,7 +570,8 @@ class SenderLink(_Link):
key, send_req = self._send_requests.popitem()
send_req.destroy(SenderLink.ABORTED, info)
if self._handler and not self._rejected:
self._handler.sender_closed(self)
with self._callback_lock:
self._handler.sender_closed(self)
def _ep_requested(self):
LOG.debug("Remote has requested a SenderLink")
@@ -551,11 +596,12 @@ class SenderLink(_Link):
elif (dist_mode == proton.Terminus.DIST_MODE_MOVE):
props["distribution-mode"] = "move"
handler.sender_requested(self._connection,
pn_link.name, # handle
pn_link.name,
req_source,
props)
with self._connection._callback_lock:
handler.sender_requested(self._connection,
pn_link.name, # handle
pn_link.name,
req_source,
props)
class ReceiverEventHandler(object):
@@ -564,14 +610,15 @@ class ReceiverEventHandler(object):
LOG.debug("receiver_active (ignored)")
def receiver_remote_closed(self, receiver_link, pn_condition):
LOG.debug("receiver_remote_closed (ignored)")
LOG.debug("receiver_remote_closed condition=%s (ignored)",
pn_condition)
def receiver_closed(self, receiver_link):
LOG.debug("receiver_closed (ignored)")
def receiver_failed(self, receiver_link, error):
"""Protocol error occurred."""
LOG.debug("receiver_failed (ignored)")
LOG.debug("receiver_failed error=%s (ignored)", error)
def message_received(self, receiver_link, message, handle):
LOG.debug("message_received (ignored)")
@@ -631,6 +678,7 @@ class ReceiverLink(_Link):
self._pn_link.target.type = proton.Terminus.UNSPECIFIED
super(ReceiverLink, self).reject(pn_condition)
@_not_reentrant
def destroy(self):
self._connection._remove_receiver(self._name)
self._connection = None
@@ -651,7 +699,8 @@ class ReceiverLink(_Link):
handle = "rmsg-%s:%x" % (self._name, self._next_handle)
self._next_handle += 1
self._unsettled_deliveries[handle] = pn_delivery
self._handler.message_received(self, msg, handle)
with self._callback_lock:
self._handler.message_received(self, msg, handle)
else:
# TODO(kgiusti): is it ok to assume Delivery.REJECTED?
pn_delivery.settle()
@@ -662,25 +711,29 @@ class ReceiverLink(_Link):
def _link_failed(self, error):
if self._handler and not self._rejected:
self._handler.receiver_failed(self, error)
with self._callback_lock:
self._handler.receiver_failed(self, error)
# endpoint state machine actions:
def _ep_active(self):
LOG.debug("ReceiverLink is up")
if self._handler and not self._rejected:
self._handler.receiver_active(self)
with self._callback_lock:
self._handler.receiver_active(self)
def _ep_need_close(self):
LOG.debug("ReceiverLink remote closed")
if self._handler and not self._rejected:
cond = self._pn_link.remote_condition
self._handler.receiver_remote_closed(self, cond)
with self._callback_lock:
self._handler.receiver_remote_closed(self, cond)
def _ep_closed(self):
LOG.debug("ReceiverLink close completed")
if self._handler and not self._rejected:
self._handler.receiver_closed(self)
with self._callback_lock:
self._handler.receiver_closed(self)
def _ep_requested(self):
LOG.debug("Remote has initiated a ReceiverLink")
@@ -705,11 +758,12 @@ class ReceiverLink(_Link):
elif (dist_mode == proton.Terminus.DIST_MODE_MOVE):
props["distribution-mode"] = "move"
handler.receiver_requested(self._connection,
pn_link.name, # handle
pn_link.name,
req_target,
props)
with self._connection._callback_lock:
handler.receiver_requested(self._connection,
pn_link.name, # handle
pn_link.name,
req_target,
props)
class _SessionProxy(Endpoint):

View File

@@ -42,32 +42,35 @@ def read_socket_input(connection, socket_obj):
if count <= 0:
return count # 0 or EOS
try:
sock_data = socket_obj.recv(count)
except socket.timeout as e:
LOG.debug("Socket timeout exception %s", str(e))
raise # caller must handle
except socket.error as e:
LOG.debug("Socket error exception %s", str(e))
err = e.errno
# ignore non-fatal errors
if (err != errno.EAGAIN and
err != errno.EWOULDBLOCK and
err != errno.EINTR):
# otherwise, unrecoverable:
connection.close_input()
raise # caller must handle
except Exception as e: # beats me... assume fatal
LOG.debug("unknown socket exception %s", str(e))
connection.close_input()
raise # caller must handle
while True:
try:
sock_data = socket_obj.recv(count)
break
except socket.timeout as e:
LOG.debug("Socket timeout exception %s", str(e))
raise # caller must handle
except socket.error as e:
LOG.debug("Socket error exception %s", str(e))
err = e.errno
if err in [errno.EAGAIN,
errno.EWOULDBLOCK,
errno.EINTR]:
# try again later
return 0
# otherwise, unrecoverable, caller must handle
raise
except Exception as e: # beats me... assume fatal
LOG.debug("unknown socket exception %s", str(e))
raise # caller must handle
if sock_data:
if len(sock_data) > 0:
count = connection.process_input(sock_data)
else:
LOG.debug("Socket closed")
count = Connection.EOS
connection.close_input()
connection.close_output()
LOG.debug("Socket recv %s bytes", count)
return count
@@ -85,30 +88,34 @@ def write_socket_output(connection, socket_obj):
if not data:
# error - has_output > 0, but no data?
return Connection.EOS
try:
count = socket_obj.send(data)
except socket.timeout as e:
LOG.debug("Socket timeout exception %s", str(e))
raise # caller must handle
except socket.error as e:
LOG.debug("Socket error exception %s", str(e))
err = e.errno
# ignore non-fatal errors
if (err != errno.EAGAIN and
err != errno.EWOULDBLOCK and
err != errno.EINTR):
# otherwise, unrecoverable
connection.close_output()
raise
except Exception as e: # beats me... assume fatal
LOG.debug("unknown socket exception %s", str(e))
connection.close_output()
raise
while True:
try:
count = socket_obj.send(data)
break
except socket.timeout as e:
LOG.debug("Socket timeout exception %s", str(e))
raise # caller must handle
except socket.error as e:
LOG.debug("Socket error exception %s", str(e))
err = e.errno
if err in [errno.EAGAIN,
errno.EWOULDBLOCK,
errno.EINTR]:
# try again later
return 0
# else assume fatal let caller handle it:
raise
except Exception as e: # beats me... assume fatal
LOG.debug("unknown socket exception %s", str(e))
raise
if count > 0:
LOG.debug("Socket sent %s bytes", count)
connection.output_written(count)
elif data:
LOG.debug("Socket closed")
count = Connection.EOS
connection.close_output()
connection.close_input()
return count

View File

@@ -19,7 +19,7 @@
from setuptools import setup
_VERSION = "2.0.0a2" # NOTE: update __init__.py too!
_VERSION = "2.0.0a3" # NOTE: update __init__.py too!
# I hack, therefore I am (productive) Some distros (which will not be named)
# don't use setup.py to install the proton python module. In this case, pip

View File

@@ -102,9 +102,14 @@ def process_connections(c1, c2, timestamp=None):
c2.process(timestamp)
def _validate_callback(connection):
"""Callbacks must only occur from the Connection.process() call."""
assert connection._in_process
def _validate_conn_callback(connection):
"""Callbacks must only occur when holding the Connection callback lock"""
assert connection._callback_lock.in_callback, \
connection._callback_lock.in_callback
def _validate_link_callback(link):
"""Callbacks must only occur when holding the Link callback lock."""
assert link._callback_lock.in_callback, link._callback_lock.in_callback
class ConnCallback(pyngus.ConnectionEventHandler):
@@ -131,29 +136,30 @@ class ConnCallback(pyngus.ConnectionEventHandler):
self.sasl_step_ct = 0
self.sasl_done_ct = 0
self.sasl_done_outcome = None
def connection_active(self, connection):
_validate_callback(connection)
_validate_conn_callback(connection)
self.active_ct += 1
def connection_failed(self, connection, error):
_validate_callback(connection)
_validate_conn_callback(connection)
self.failed_ct += 1
self.failed_error = error
def connection_remote_closed(self, connection, error=None):
_validate_callback(connection)
_validate_conn_callback(connection)
self.remote_closed_ct += 1
self.remote_closed_error = error
def connection_closed(self, connection):
_validate_callback(connection)
_validate_conn_callback(connection)
self.closed_ct += 1
def sender_requested(self, connection, link_handle,
name, requested_source,
properties):
_validate_callback(connection)
_validate_conn_callback(connection)
self.sender_requested_ct += 1
args = ConnCallback.RequestArgs(link_handle, name,
requested_source, None,
@@ -163,7 +169,7 @@ class ConnCallback(pyngus.ConnectionEventHandler):
def receiver_requested(self, connection, link_handle,
name, requested_target,
properties):
_validate_callback(connection)
_validate_conn_callback(connection)
self.receiver_requested_ct += 1
args = ConnCallback.RequestArgs(link_handle, name,
None, requested_target,
@@ -171,12 +177,13 @@ class ConnCallback(pyngus.ConnectionEventHandler):
self.receiver_requested_args.append(args)
def sasl_step(self, connection, pn_sasl):
_validate_callback(connection)
_validate_conn_callback(connection)
self.sasl_step_ct += 1
def sasl_done(self, connection, pn_sasl, result):
_validate_callback(connection)
_validate_conn_callback(connection)
self.sasl_done_ct += 1
self.sasl_done_outcome = result
class SenderCallback(pyngus.SenderEventHandler):
@@ -188,20 +195,20 @@ class SenderCallback(pyngus.SenderEventHandler):
self.credit_granted_ct = 0
def sender_active(self, sender_link):
_validate_callback(sender_link.connection)
_validate_link_callback(sender_link)
self.active_ct += 1
def sender_remote_closed(self, sender_link, error=None):
_validate_callback(sender_link.connection)
_validate_link_callback(sender_link)
self.remote_closed_ct += 1
self.remote_closed_error = error
def sender_closed(self, sender_link):
_validate_callback(sender_link.connection)
_validate_link_callback(sender_link)
self.closed_ct += 1
def credit_granted(self, sender_link):
_validate_callback(sender_link.connection)
_validate_link_callback(sender_link)
self.credit_granted_ct += 1
@@ -215,7 +222,7 @@ class DeliveryCallback(object):
self.count = 0
def __call__(self, link, handle, status, info):
_validate_callback(link.connection)
_validate_link_callback(link)
self.link = link
self.handle = handle
self.status = status
@@ -233,19 +240,23 @@ class ReceiverCallback(pyngus.ReceiverEventHandler):
self.received_messages = []
def receiver_active(self, receiver_link):
_validate_callback(receiver_link.connection)
_validate_link_callback(receiver_link)
_validate_conn_callback(receiver_link.connection)
self.active_ct += 1
def receiver_remote_closed(self, receiver_link, error=None):
_validate_callback(receiver_link.connection)
_validate_link_callback(receiver_link)
_validate_conn_callback(receiver_link.connection)
self.remote_closed_ct += 1
self.remote_closed_error = error
def receiver_closed(self, receiver_link):
_validate_callback(receiver_link.connection)
_validate_link_callback(receiver_link)
_validate_conn_callback(receiver_link.connection)
self.closed_ct += 1
def message_received(self, receiver_link, message, handle):
_validate_callback(receiver_link.connection)
_validate_link_callback(receiver_link)
_validate_conn_callback(receiver_link.connection)
self.message_received_ct += 1
self.received_messages.append((message, handle))

View File

@@ -435,8 +435,8 @@ class APITest(common.Test):
assert cb1.failed_ct > 0
assert cb1.failed_error
def test_process_reentrancy(self):
"""Catch any attempt to re-enter Connection.process() from a
def test_non_reentrant_callback(self):
"""Catch any attempt to call a non-reentrant Connection method from a
callback."""
class BadCallback(common.ConnCallback):
def connection_active(self, connection):
@@ -896,5 +896,7 @@ mech_list: EXTERNAL DIGEST-MD5 SCRAM-SHA-1 CRAM-MD5 PLAIN ANONYMOUS
assert not c1.active, c1.active
assert c1_events.failed_ct == 1, c1_events.failed_ct
assert not c2.active, c2.active
assert c2_events.sasl_done_ct == 0, c2_events.sasl_done_ct
assert c2_events.sasl_done_ct == 1, c2_events.sasl_done_ct
# outcome 1 == auth error
assert c2_events.sasl_done_outcome == 1, c2_events.sasl_done_outcome
assert c2_events.failed_ct == 1, c2_events.failed_ct

View File

@@ -704,3 +704,63 @@ class APITest(common.Test):
assert mode and mode == 'unsettled'
mode = args.properties.get('rcv-settle-mode')
assert mode and mode == 'second'
def test_non_reentrant_callback(self):
class SenderBadCallback(common.SenderCallback):
def sender_active(self, sender):
# Illegal:
sender.destroy()
class ReceiverBadCallback(common.ReceiverCallback):
def receiver_active(self, receiver):
# Illegal:
receiver.destroy()
class ReceiverBadCallback2(common.ReceiverCallback):
def receiver_active(self, receiver):
# Illegal:
receiver.connection.destroy()
sc = SenderBadCallback()
sender = self.conn1.create_sender("src1", "tgt1",
event_handler=sc)
sender.open()
self.process_connections()
assert self.conn2_handler.receiver_requested_ct
args = self.conn2_handler.receiver_requested_args[-1]
receiver = self.conn2.accept_receiver(args.link_handle)
receiver.open()
try:
self.process_connections()
assert False, "RuntimeError expected!"
except RuntimeError:
pass
sender = self.conn1.create_sender("src2", "tgt2")
sender.open()
self.process_connections()
args = self.conn2_handler.receiver_requested_args[-1]
rc = ReceiverBadCallback()
receiver = self.conn2.accept_receiver(args.link_handle,
event_handler=rc)
receiver.open()
try:
self.process_connections()
assert False, "RuntimeError expected!"
except RuntimeError:
pass
sender = self.conn1.create_sender("src3", "tgt3")
sender.open()
self.process_connections()
args = self.conn2_handler.receiver_requested_args[-1]
rc = ReceiverBadCallback2()
receiver = self.conn2.accept_receiver(args.link_handle,
event_handler=rc)
receiver.open()
try:
self.process_connections()
assert False, "RuntimeError expected!"
except RuntimeError:
pass