Merge "Always set all socket timeouts"

This commit is contained in:
Jenkins 2016-03-23 07:13:28 +00:00 committed by Gerrit Code Review
commit 3802dd5e4b
2 changed files with 33 additions and 28 deletions

View File

@ -48,6 +48,9 @@ from oslo_messaging._i18n import _LW
from oslo_messaging import _utils from oslo_messaging import _utils
from oslo_messaging import exceptions from oslo_messaging import exceptions
# NOTE(sileht): don't exists in py2 socket module
TCP_USER_TIMEOUT = 18
rabbit_opts = [ rabbit_opts = [
cfg.StrOpt('kombu_ssl_version', cfg.StrOpt('kombu_ssl_version',
@ -647,6 +650,7 @@ class Connection(object):
# the kombu underlying connection works # the kombu underlying connection works
self._set_current_channel(None) self._set_current_channel(None)
self.ensure(method=lambda: self.connection.connection) self.ensure(method=lambda: self.connection.connection)
self.set_transport_socket_timeout()
def ensure(self, method, retry=None, def ensure(self, method, retry=None,
recoverable_error_callback=None, error_callback=None, recoverable_error_callback=None, error_callback=None,
@ -714,6 +718,8 @@ class Connection(object):
"""Callback invoked when the kombu reconnects and creates """Callback invoked when the kombu reconnects and creates
a new channel, we use it the reconfigure our consumers. a new channel, we use it the reconfigure our consumers.
""" """
self.set_transport_socket_timeout()
self._set_current_channel(new_channel) self._set_current_channel(new_channel)
for consumer in self._consumers: for consumer in self._consumers:
consumer.declare(self) consumer.declare(self)
@ -834,8 +840,7 @@ class Connection(object):
self._heartbeat_support_log_emitted = True self._heartbeat_support_log_emitted = True
return False return False
@contextlib.contextmanager def set_transport_socket_timeout(self, timeout=None):
def _transport_socket_timeout(self, timeout):
# NOTE(sileht): they are some case where the heartbeat check # NOTE(sileht): they are some case where the heartbeat check
# or the producer.send return only when the system socket # or the producer.send return only when the system socket
# timeout if reach. kombu doesn't allow use to customise this # timeout if reach. kombu doesn't allow use to customise this
@ -844,27 +849,37 @@ class Connection(object):
# kombu==3.0.33. Once the commit below is released, we should # kombu==3.0.33. Once the commit below is released, we should
# try to set the socket timeout in the constructor: # try to set the socket timeout in the constructor:
# https://github.com/celery/py-amqp/pull/64 # https://github.com/celery/py-amqp/pull/64
heartbeat_timeout = self.heartbeat_timeout_threshold
if self._heartbeat_supported_and_enabled():
# NOTE(sileht): we are supposed to send heartbeat every
# heartbeat_timeout, no need to wait more otherwise will
# disconnect us, so raise timeout earlier ourself
if timeout is None:
timeout = heartbeat_timeout
else:
timeout = min(heartbeat_timeout, timeout)
try: try:
sock = self.channel.connection.sock sock = self.channel.connection.sock
except AttributeError as e: except AttributeError as e:
# Level is set to debug because otherwise we would spam the logs # Level is set to debug because otherwise we would spam the logs
LOG.debug('Failed to get socket attribute: %s' % str(e)) LOG.debug('Failed to get socket attribute: %s' % str(e))
sock = None else:
if sock:
orig_timeout = sock.gettimeout()
sock.settimeout(timeout) sock.settimeout(timeout)
sock.setsockopt(socket.IPPROTO_TCP, TCP_USER_TIMEOUT,
timeout * 1000 if timeout is not None else 0)
@contextlib.contextmanager
def _transport_socket_timeout(self, timeout):
self.set_transport_socket_timeout(timeout)
yield yield
if sock: self.set_transport_socket_timeout()
sock.settimeout(orig_timeout)
def _heartbeat_check(self): def _heartbeat_check(self):
# NOTE(sileht): we are supposed to send at least one heartbeat # NOTE(sileht): we are supposed to send at least one heartbeat
# every heartbeat_timeout_threshold, so no need to way more # every heartbeat_timeout_threshold, so no need to way more
with self._transport_socket_timeout( self.connection.heartbeat_check(rate=self.heartbeat_rate)
self.heartbeat_timeout_threshold):
self.connection.heartbeat_check(
rate=self.heartbeat_rate)
def _heartbeat_start(self): def _heartbeat_start(self):
if self._heartbeat_supported_and_enabled(): if self._heartbeat_supported_and_enabled():
@ -1089,25 +1104,15 @@ class Connection(object):
auto_declare=not exchange.passive, auto_declare=not exchange.passive,
routing_key=routing_key) routing_key=routing_key)
# NOTE(sileht): no need to wait more, caller expects
# a answer before timeout is reached
transport_timeout = timeout
heartbeat_timeout = self.heartbeat_timeout_threshold
if (self._heartbeat_supported_and_enabled() and (
transport_timeout is None or
transport_timeout > heartbeat_timeout)):
# NOTE(sileht): we are supposed to send heartbeat every
# heartbeat_timeout, no need to wait more otherwise will
# disconnect us, so raise timeout earlier ourself
transport_timeout = heartbeat_timeout
log_info = {'msg': msg, log_info = {'msg': msg,
'who': exchange or 'default', 'who': exchange or 'default',
'key': routing_key} 'key': routing_key}
LOG.trace('Connection._publish: sending message %(msg)s to' LOG.trace('Connection._publish: sending message %(msg)s to'
' %(who)s with routing key %(key)s', log_info) ' %(who)s with routing key %(key)s', log_info)
with self._transport_socket_timeout(transport_timeout):
# NOTE(sileht): no need to wait more, caller expects
# a answer before timeout is reached
with self._transport_socket_timeout(timeout):
producer.publish(msg, expiration=self._get_expiration(timeout), producer.publish(msg, expiration=self._get_expiration(timeout),
compression=self.kombu_compression) compression=self.kombu_compression)

View File

@ -92,11 +92,11 @@ class TestHeartbeat(test_utils.BaseTestCase):
if not heartbeat_side_effect: if not heartbeat_side_effect:
self.assertEqual(1, fake_ensure_connection.call_count) self.assertEqual(1, fake_ensure_connection.call_count)
self.assertEqual(3, fake_logger.debug.call_count) self.assertEqual(2, fake_logger.debug.call_count)
self.assertEqual(0, fake_logger.info.call_count) self.assertEqual(0, fake_logger.info.call_count)
else: else:
self.assertEqual(2, fake_ensure_connection.call_count) self.assertEqual(2, fake_ensure_connection.call_count)
self.assertEqual(3, fake_logger.debug.call_count) self.assertEqual(2, fake_logger.debug.call_count)
self.assertEqual(1, fake_logger.info.call_count) self.assertEqual(1, fake_logger.info.call_count)
self.assertIn(mock.call(info, mock.ANY), self.assertIn(mock.call(info, mock.ANY),
fake_logger.info.mock_calls) fake_logger.info.mock_calls)