# Copyright (C) 2014 eNovance SAS <licensing@enovance.com>
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import functools
import operator
import random
import threading
import time

import mock
try:
    import qpid
except ImportError:
    qpid = None
from six.moves import _thread
import testscenarios
import testtools

from oslo import messaging
from oslo.messaging._drivers import impl_qpid as qpid_driver
from tests import utils as test_utils


# Module-level ``load_tests`` hook supplied by testscenarios; the test classes
# below declare a ``scenarios`` attribute for it to work with.
load_tests = testscenarios.load_tests_apply_scenarios

# Broker address used both to probe for a running qpidd and, when one is
# found, to open the real send/receive connections in _QpidBaseTestCase.
QPID_BROKER = 'localhost:5672'


class TestQpidDriverLoad(test_utils.BaseTestCase):
    """Check that selecting the 'qpid' transport driver loads QpidDriver."""

    def setUp(self):
        """Select 'qpid' as the transport driver in the messaging config."""
        super(TestQpidDriverLoad, self).setUp()
        self.messaging_conf.transport_driver = 'qpid'

    def test_driver_load(self):
        """The driver behind the created transport is a QpidDriver."""
        driver = messaging.get_transport(self.conf)._driver
        self.assertIsInstance(driver, qpid_driver.QpidDriver)


def _is_qpidd_service_running():

    """this function checks if the qpid service is running or not."""

    qpid_running = True
    try:
        broker = QPID_BROKER
        connection = qpid.messaging.Connection(broker)
        connection.open()
    except Exception:
        # qpid service is not running.
        qpid_running = False
    else:
        connection.close()

    return qpid_running


class _QpidBaseTestCase(test_utils.BaseTestCase):
    """Shared fixture for the qpid driver tests.

    When a broker answers at QPID_BROKER the tests talk to it through two
    real connections (one for receiving, one for sending); otherwise both
    sessions are the module's in-memory fake session.
    """

    @testtools.skipIf(qpid is None, "qpid not available")
    def setUp(self):
        """Create the send/receive sessions and the bookkeeping lists."""
        super(_QpidBaseTestCase, self).setUp()
        self.messaging_conf.transport_driver = 'qpid'

        # Use the fake session whenever no real broker could be reached.
        self.fake_qpid = not _is_qpidd_service_running()

        if not self.fake_qpid:
            self.broker = QPID_BROKER

            # Consumer side: dedicated connection plus the session used to
            # receive the messages.
            self.con_receive = qpid.messaging.Connection(self.broker)
            self.con_receive.open()
            self.session_receive = self.con_receive.session()

            # Publisher side: dedicated connection plus the session used
            # to send the messages.
            self.con_send = qpid.messaging.Connection(self.broker)
            self.con_send.open()
            self.session_send = self.con_send.session()
        else:
            self.session_receive = get_fake_qpid_session()
            self.session_send = get_fake_qpid_session()

        # Messages the test expects to see and the ones it actually got.
        self._expected = []
        self._messages = []
        self.initialized = True

    def tearDown(self):
        """Drop fake exchanges or close the real broker connections."""
        super(_QpidBaseTestCase, self).tearDown()

        if not self.initialized:
            return

        if self.fake_qpid:
            # The fake keeps its exchanges in module-level state, so they
            # have to be emptied between tests.
            _fake_session.flush_exchanges()
            return

        self.con_receive.close()
        self.con_send.close()


class TestQpidTransportURL(_QpidBaseTestCase):
    """Check the broker parameters derived from a transport URL.

    Each scenario supplies a ``url`` and the ``expected`` list of
    host/username/password dicts the connection should end up with.
    """

    scenarios = [
        ('none', {
            'url': None,
            'expected': [
                {'host': 'localhost:5672', 'username': '', 'password': ''},
            ],
        }),
        ('empty', {
            'url': 'qpid:///',
            'expected': [
                {'host': 'localhost:5672', 'username': '', 'password': ''},
            ],
        }),
        ('localhost', {
            'url': 'qpid://localhost/',
            'expected': [
                {'host': 'localhost', 'username': '', 'password': ''},
            ],
        }),
        ('no_creds', {
            'url': 'qpid://host/',
            'expected': [
                {'host': 'host', 'username': '', 'password': ''},
            ],
        }),
        ('no_port', {
            'url': 'qpid://user:password@host/',
            'expected': [
                {'host': 'host', 'username': 'user', 'password': 'password'},
            ],
        }),
        ('full_url', {
            'url': 'qpid://user:password@host:10/',
            'expected': [
                {'host': 'host:10',
                 'username': 'user',
                 'password': 'password'},
            ],
        }),
        ('full_two_url', {
            'url': ('qpid://user:password@host:10,'
                    'user2:password2@host2:12/'),
            'expected': [
                {'host': 'host:10',
                 'username': 'user',
                 'password': 'password'},
                {'host': 'host2:12',
                 'username': 'user2',
                 'password': 'password2'},
            ],
        }),
    ]

    @mock.patch.object(qpid_driver.Connection, 'reconnect')
    def test_transport_url(self, *args):
        """brokers_params matches the scenario, ignoring broker order."""
        transport = messaging.get_transport(self.conf, self.url)
        self.addCleanup(transport.cleanup)

        # Both lists are sorted by host so that only the content, not the
        # ordering, is compared.
        by_host = operator.itemgetter('host')
        actual = transport._driver._get_connection().brokers_params
        self.assertEqual(sorted(self.expected, key=by_host),
                         sorted(actual, key=by_host))


class TestQpidInvalidTopologyVersion(_QpidBaseTestCase):
    """Unit test cases to test invalid qpid topology version.

    Every scenario names a consumer class and a publisher class together
    with the extra keyword arguments needed to construct them.
    """

    scenarios = [
        ('direct', dict(consumer_cls=qpid_driver.DirectConsumer,
                        consumer_kwargs={},
                        publisher_cls=qpid_driver.DirectPublisher,
                        publisher_kwargs={})),
        ('topic', dict(consumer_cls=qpid_driver.TopicConsumer,
                       consumer_kwargs={'exchange_name': 'openstack'},
                       publisher_cls=qpid_driver.TopicPublisher,
                       publisher_kwargs={'exchange_name': 'openstack'})),
        ('fanout', dict(consumer_cls=qpid_driver.FanoutConsumer,
                        consumer_kwargs={},
                        publisher_cls=qpid_driver.FanoutPublisher,
                        publisher_kwargs={})),
    ]

    def setUp(self):
        """Configure a topology version (-1) that the test treats as bad."""
        super(TestQpidInvalidTopologyVersion, self).setUp()
        self.config(qpid_topology_version=-1)

    @staticmethod
    def _failure_message(factory, *args, **kwargs):
        """Call ``factory`` and return the text of the exception it raises.

        ``str(e)`` is used rather than ``e.message`` because the latter
        attribute does not exist on Python 3 exceptions.

        :returns: the exception text, or '' if nothing was raised.
        """
        try:
            factory(*args, **kwargs)
        except Exception as e:
            return str(e)
        return ''

    def test_invalid_topology_version(self):
        """Consumer and publisher both reject the bad topology version."""
        def consumer_callback(msg):
            pass

        msgid_or_topic = 'test'

        # not using self.assertRaises because
        # 1. qpid driver raises Exception(msg) for invalid topology version
        # 2. flake8 - H202 assertRaises Exception too broad
        exception_msg = ("Invalid value for qpid_topology_version: %d" %
                         self.conf.qpid_topology_version)

        recvd_exc_msg = self._failure_message(self.consumer_cls,
                                              self.conf,
                                              self.session_receive,
                                              msgid_or_topic,
                                              consumer_callback,
                                              **self.consumer_kwargs)
        self.assertEqual(exception_msg, recvd_exc_msg)

        recvd_exc_msg = self._failure_message(self.publisher_cls,
                                              self.conf,
                                              self.session_send,
                                              topic=msgid_or_topic,
                                              **self.publisher_kwargs)
        self.assertEqual(exception_msg, recvd_exc_msg)


class TestQpidDirectConsumerPublisher(_QpidBaseTestCase):
    """Unit test cases to test DirectConsumer and Direct Publisher.

    Scenarios are the product of the topology versions and the message
    counts declared below.
    """

    # NOTE(review): ``qpid_topology`` is set by these scenarios but is not
    # read anywhere in this class -- confirm whether it was meant to be
    # applied to ``qpid_topology_version``.
    _n_qpid_topology = [
        ('v1', {'qpid_topology': 1}),
        ('v2', {'qpid_topology': 2}),
    ]

    _n_msgs = [
        ('single', {'no_msgs': 1}),
        ('multiple', {'no_msgs': 10}),
    ]

    @classmethod
    def generate_scenarios(cls):
        """Build ``cls.scenarios`` as topology versions x message counts."""
        cls.scenarios = testscenarios.multiply_scenarios(
            cls._n_qpid_topology, cls._n_msgs)

    def consumer_callback(self, msg):
        """Record a message handed over by the DirectConsumer.

        For a dict only its 'content' entry is stored; anything else is
        stored as is.  The collected list is later compared with the
        expected messages.
        """
        payload = msg['content'] if isinstance(msg, dict) else msg
        self._messages.append(payload)

    def test_qpid_direct_consumer_producer(self):
        """Messages published directly arrive complete and in order."""
        self.msgid = str(random.randint(1, 100))

        # Consumer and publisher are bound to the same message id.
        self.dir_cons = qpid_driver.DirectConsumer(
            self.conf, self.session_receive, self.msgid,
            self.consumer_callback)
        self.dir_pub = qpid_driver.DirectPublisher(
            self.conf, self.session_send, self.msgid)

        def publish(count):
            for seq in range(count):
                body = str(seq)
                self._expected.append(body)
                self.dir_pub.send({'content_type': 'text/plain',
                                   'content': body})

        def drain(count):
            for _ in range(count):
                self.dir_cons.consume()

        receiver = threading.Thread(target=drain, args=(self.no_msgs,))
        sender = threading.Thread(target=publish, args=(self.no_msgs,))

        # The receiver is started first, then the sender; both are joined
        # before anything is asserted.
        receiver.start()
        sender.start()
        receiver.join()
        sender.join()

        self.assertEqual(self.no_msgs, len(self._messages))
        self.assertEqual(self._expected, self._messages)


TestQpidDirectConsumerPublisher.generate_scenarios()


class TestQpidTopicAndFanout(_QpidBaseTestCase):
    """Unit Test cases to test TopicConsumer and
    TopicPublisher classes of the qpid driver
    and FanoutConsumer and FanoutPublisher classes
    of the qpid driver

    Scenarios are the product of topology versions, message counts,
    sender counts, receiver counts and the exchange class (topic/fanout).
    """

    # NOTE(review): ``qpid_topology`` is set by these scenarios but is not
    # read anywhere in this class -- confirm whether it was meant to be
    # applied to ``qpid_topology_version``.
    _n_qpid_topology = [
        ('v1', dict(qpid_topology=1)),
        ('v2', dict(qpid_topology=2)),
    ]

    _n_msgs = [
        ('single', dict(no_msgs=1)),
        ('multiple', dict(no_msgs=10)),
    ]

    _n_senders = [
        ('single', dict(no_senders=1)),
        ('multiple', dict(no_senders=10)),
    ]

    _n_receivers = [
        ('single', dict(no_receivers=1)),
    ]
    _exchange_class = [
        ('topic', dict(consumer_cls=qpid_driver.TopicConsumer,
                       consumer_kwargs={'exchange_name': 'openstack'},
                       publisher_cls=qpid_driver.TopicPublisher,
                       publisher_kwargs={'exchange_name': 'openstack'},
                       topic='topictest.test',
                       receive_topic='topictest.test')),
        ('fanout', dict(consumer_cls=qpid_driver.FanoutConsumer,
                        consumer_kwargs={},
                        publisher_cls=qpid_driver.FanoutPublisher,
                        publisher_kwargs={},
                        topic='fanouttest',
                        receive_topic='fanouttest')),
    ]

    @classmethod
    def generate_scenarios(cls):
        """Build ``cls.scenarios`` as the product of all scenario lists."""
        cls.scenarios = testscenarios.multiply_scenarios(cls._n_qpid_topology,
                                                         cls._n_msgs,
                                                         cls._n_senders,
                                                         cls._n_receivers,
                                                         cls._exchange_class)

    def setUp(self):
        """Reset the per-test bookkeeping for senders and receivers."""
        super(TestQpidTopicAndFanout, self).setUp()

        # to store the expected messages and the
        # actual received messages
        #
        # NOTE(dhellmann): These are dicts, where the base class uses
        # lists.
        self._expected = {}
        self._messages = {}

        self._senders = []
        self._receivers = []

        self._sender_threads = []
        self._receiver_threads = []

    def consumer_callback(self, msg):
        """callback function called by the ConsumerBase class of
        qpid driver.
        Message will be received in the format x-y
        where x is the sender id and y is the msg number of the sender
        extract the sender id 'x' and store the msg 'x-y' with 'x' as
        the key

        Messages are grouped first by the id of the thread running the
        callback and then by sender id.
        """

        if isinstance(msg, dict):
            msgcontent = msg['content']
        else:
            msgcontent = msg

        sender_key = msgcontent.split('-')[0]
        receiver_key = _thread.get_ident()

        per_receiver = self._messages.setdefault(receiver_key, {})
        per_receiver.setdefault(sender_key, []).append(msgcontent)

    def _try_send_msg(self, sender_id, no_msgs):
        """Publish ``no_msgs`` messages through sender ``sender_id``.

        Every message is also recorded in ``self._expected`` under the
        sender's id so that it can be compared with what the receivers
        collected.
        """
        key = str(sender_id)
        for i in range(no_msgs):
            sendmsg = '%s-%s' % (key, str(i))
            self._expected.setdefault(key, []).append(sendmsg)
            send_dict = {'content_type': 'text/plain', 'content': sendmsg}
            self._senders[sender_id].send(send_dict)

    def _try_receive_msg(self, receiver_id, no_msgs):
        """Try to consume one message per message sent by every sender.

        Each message is waited for with at most 50 polls, 0.05 seconds
        apart; if none becomes available that slot is skipped.
        """
        for i in range(self.no_senders * no_msgs):
            no_of_attempts = 0

            # ConsumerBase.consume blocks indefinitely until a message
            # is received.
            # So qpid_receiver.available() is called before calling
            # ConsumerBase.consume() so that we are not
            # blocked indefinitely
            qpid_receiver = self._receivers[receiver_id].get_receiver()
            while no_of_attempts < 50:
                if qpid_receiver.available() > 0:
                    self._receivers[receiver_id].consume()
                    break
                no_of_attempts += 1
                time.sleep(0.05)

    def test_qpid_topic_and_fanout(self):
        """Every receiver collects exactly what the senders published."""
        for receiver_id in range(self.no_receivers):
            consumer = self.consumer_cls(self.conf,
                                         self.session_receive,
                                         self.receive_topic,
                                         self.consumer_callback,
                                         **self.consumer_kwargs)
            self._receivers.append(consumer)

            # create receivers threads
            thread = threading.Thread(target=self._try_receive_msg,
                                      args=(receiver_id, self.no_msgs,))
            self._receiver_threads.append(thread)

        for sender_id in range(self.no_senders):
            publisher = self.publisher_cls(self.conf,
                                           self.session_send,
                                           topic=self.topic,
                                           **self.publisher_kwargs)
            self._senders.append(publisher)

            # create sender threads
            thread = threading.Thread(target=self._try_send_msg,
                                      args=(sender_id, self.no_msgs,))
            self._sender_threads.append(thread)

        for thread in self._receiver_threads:
            thread.start()

        for thread in self._sender_threads:
            thread.start()

        for thread in self._receiver_threads:
            thread.join()

        for thread in self._sender_threads:
            thread.join()

        # Each receiver should receive all the messages sent by
        # the sender(s).
        # So, iterate through each of the receiver items in
        # self._messages and compare with the expected messages.

        self.assertEqual(self.no_senders, len(self._expected))
        self.assertEqual(self.no_receivers, len(self._messages))

        # Only the per-receiver message dicts are needed here; values()
        # exists on both Python 2 and Python 3, unlike iteritems().
        for messages in self._messages.values():
            self.assertEqual(self._expected, messages)

TestQpidTopicAndFanout.generate_scenarios()


class AddressNodeMatcher(object):
    """Equality helper that matches only the node part of a qpid address.

    An address string has the form ``"<node> ; <options>"``.  An instance
    compares equal to any address whose text before the first ';', with
    surrounding whitespace removed, equals ``node``.  The tests use it as
    an argument inside ``mock.call(...)`` expectations.
    """

    def __init__(self, node):
        # node: the expected node name, e.g. 'amq.direct/msg_id'
        self.node = node

    def __eq__(self, address):
        """Return True if ``address`` names this matcher's node."""
        return address.split(';')[0].strip() == self.node

    def __ne__(self, address):
        """Inverse of __eq__; Python 2 does not derive it automatically."""
        return not self.__eq__(address)

    def __repr__(self):
        # Shown in mock assertion failures, so make the node visible.
        return '%s(%r)' % (type(self).__name__, self.node)


class TestDriverInterface(_QpidBaseTestCase):
    """Unit Test cases to test the amqpdriver with qpid

    ``qpid.messaging.Connection`` is mocked in every test, and the
    assertions inspect the addresses handed to the mocked session.
    """

    def setUp(self):
        """Create a driver whose _get_connection is wrapped by a mock."""
        super(TestDriverInterface, self).setUp()
        self.config(qpid_topology_version=2)
        self.driver = messaging.get_transport(self.conf)._driver

        # Keep a reference to the real method before it gets patched.
        real_get_connection = self.driver._get_connection

        def unpooled_connection(pooled=True):
            # Whatever the caller asked for, the real method is always
            # called with False.
            return real_get_connection(False)

        patcher = mock.patch.object(self.driver, '_get_connection',
                                    side_effect=unpooled_connection)
        patcher.start()
        self.addCleanup(patcher.stop)

    def test_listen_and_direct_send(self):
        """listen() creates three receivers; direct_send one sender."""
        target = messaging.Target(exchange="exchange_test",
                                  topic="topic_test",
                                  server="server_test")

        with mock.patch('qpid.messaging.Connection') as conn_cls:
            session = conn_cls.return_value.session.return_value
            # One distinct receiver mock per expected consumer.
            session.receiver.side_effect = [mock.Mock() for _ in range(3)]

            listener = self.driver.listen(target)
            listener.conn.direct_send("msg_id", {})

        self.assertEqual(3, len(listener.conn.consumers))

        expected_nodes = (
            'amq.topic/topic/exchange_test/topic_test',
            'amq.topic/topic/exchange_test/topic_test.server_test',
            'amq.topic/fanout/topic_test',
        )
        session.receiver.assert_has_calls(
            [mock.call(AddressNodeMatcher(node)) for node in expected_nodes])
        session.sender.assert_called_with(
            AddressNodeMatcher("amq.direct/msg_id"))

    def test_send(self):
        """send() addresses the topic.server node of the exchange."""
        target = messaging.Target(exchange="exchange_test",
                                  topic="topic_test",
                                  server="server_test")
        with mock.patch('qpid.messaging.Connection') as conn_cls:
            session = conn_cls.return_value.session.return_value

            self.driver.send(target, {}, {})

            expected_node = AddressNodeMatcher(
                "amq.topic/topic/exchange_test/topic_test.server_test")
            session.sender.assert_called_with(expected_node)

    def test_send_notification(self):
        """send_notification() addresses the topic node of the exchange."""
        target = messaging.Target(exchange="exchange_test",
                                  topic="topic_test.info")
        with mock.patch('qpid.messaging.Connection') as conn_cls:
            session = conn_cls.return_value.session.return_value

            self.driver.send_notification(target, {}, {}, "2.0")

            expected_node = AddressNodeMatcher(
                "amq.topic/topic/exchange_test/topic_test.info")
            session.sender.assert_called_with(expected_node)


class TestQpidReconnectOrder(test_utils.BaseTestCase):
    """Unit Test cases to test reconnection
    """

    @testtools.skipIf(qpid is None, "qpid not available")
    def test_reconnect_order(self):
        """Reconnecting once per broker touches every configured broker."""
        brokers = ['host1', 'host2', 'host3', 'host4', 'host5']
        self.config(qpid_hosts=brokers)

        with mock.patch('qpid.messaging.Connection') as conn_mock:
            # starting from the first broker in the list
            connection = qpid_driver.Connection(
                self.conf, messaging.TransportURL.parse(self.conf, None))

            # reconnect will advance to the next broker, one broker per
            # attempt, and then wrap to the start of the list once the end is
            # reached
            for _ in range(len(brokers)):
                connection.reconnect()

        # NOTE(review): __nonzero__ is the Python 2 truth-test hook; under
        # Python 3 the recorded call would presumably be __bool__ --
        # confirm before running this test on Python 3.
        expected = [
            expected_call
            for broker in brokers
            for expected_call in (mock.call("%s:5672" % broker),
                                  mock.call().open(),
                                  mock.call().session(),
                                  mock.call().opened(),
                                  mock.call().opened().__nonzero__(),
                                  mock.call().close())
        ]

        conn_mock.assert_has_calls(expected, any_order=True)


def synchronized(func):
    """Decorator that serializes calls to ``func`` with a lock.

    A single ``threading.Lock`` is created when the function is decorated
    and stored on it as ``__lock__``.  The lock therefore belongs to the
    decorated function itself: for a method it is shared by every instance
    of the class, and it is held for the whole duration of each call.

    :param func: the callable to protect.
    :returns: a wrapper that acquires the lock, calls ``func`` with the
              given arguments and returns its result.
    """
    func.__lock__ = threading.Lock()

    # wraps() keeps the original __name__/__doc__ on the wrapper so that
    # decorated methods stay recognisable in tracebacks and help output.
    @functools.wraps(func)
    def synced_func(*args, **kws):
        with func.__lock__:
            return func(*args, **kws)

    return synced_func


class FakeQpidMsgManager(object):
    """In-memory store of exchanges used by the fake qpid session.

    ``_exchanges`` maps an exchange key to a dict with two entries:

    * ``'msgs'``: list of the messages added to the exchange, in order;
    * ``'consumers'``: dict mapping a consumer id to an index into
      ``'msgs'``, starting at 0 and advanced by inc_consumer_index().

    The methods that modify this state are wrapped with ``synchronized``.
    """

    def __init__(self):
        self._exchanges = {}

    @synchronized
    def add_exchange(self, exchange):
        """Register ``exchange`` unless it is already known."""
        self._exchanges.setdefault(exchange, {'msgs': [], 'consumers': {}})

    @synchronized
    def add_exchange_consumer(self, exchange, consumer_id):
        """Attach a consumer to ``exchange`` with its index set to 0."""
        self._exchanges[exchange]['consumers'][consumer_id] = 0

    @synchronized
    def add_exchange_msg(self, exchange, msg):
        """Append ``msg`` to the messages of ``exchange``."""
        self._exchanges[exchange]['msgs'].append(msg)

    def get_exchange_msg(self, exchange, index):
        """Return message number ``index`` of ``exchange``.

        Indexing the message list raises IndexError when there is no
        message at that position.
        """
        return self._exchanges[exchange]['msgs'][index]

    def get_no_exch_msgs(self, exchange):
        """Return how many messages ``exchange`` holds."""
        return len(self._exchanges[exchange]['msgs'])

    def get_exch_cons_index(self, exchange, consumer_id):
        """Return the current message index of the given consumer."""
        return self._exchanges[exchange]['consumers'][consumer_id]

    @synchronized
    def inc_consumer_index(self, exchange, consumer_id):
        """Advance the given consumer's message index by one."""
        self._exchanges[exchange]['consumers'][consumer_id] += 1

_fake_qpid_msg_manager = FakeQpidMsgManager()


class FakeQpidSessionSender(object):
    """Sender returned by FakeQpidSession.sender().

    ``target`` is the exchange key that sent messages are stored under in
    the module-level fake message manager.
    """

    def __init__(self, session, id, target, options):
        self.session, self.id = session, id
        self.target, self.options = target, options

    @synchronized
    def send(self, object, sync=True, timeout=None):
        """Store ``object`` as a new message of the target exchange.

        ``sync`` and ``timeout`` are accepted but not used.
        """
        _fake_qpid_msg_manager.add_exchange_msg(self.target, object)

    def close(self, timeout=None):
        """Do nothing; the fake sender has nothing to release."""
        pass


class FakeQpidSessionReceiver(object):
    """Receiver returned by FakeQpidSession.receiver().

    ``source`` is the exchange key to read from and ``id`` identifies this
    receiver's read position inside the module-level fake message manager.
    """

    def __init__(self, session, id, source, options):
        self.session, self.id = session, id
        self.source, self.options = source, options

    @synchronized
    def fetch(self, timeout=None):
        """Return this receiver's next message, polling until a deadline.

        The exchange is polled every 50 ms.  Without a ``timeout`` a
        default of 30 seconds applies, to avoid looping forever, and an
        Exception is raised once it expires.  When an explicit ``timeout``
        expires the method simply falls through and returns None.
        """
        wait = 30 if timeout is None else timeout
        deadline = time.time() + wait

        while time.time() <= deadline:
            index = _fake_qpid_msg_manager.get_exch_cons_index(self.source,
                                                               self.id)
            try:
                msg = _fake_qpid_msg_manager.get_exchange_msg(self.source,
                                                              index)
            except IndexError:
                # Nothing at this position yet; wait a little and retry.
                time.sleep(0.050)
                continue

            _fake_qpid_msg_manager.inc_consumer_index(self.source, self.id)
            return qpid.messaging.Message(msg)

        if timeout is None:
            raise Exception('timed out waiting for reply')

    def close(self, timeout=None):
        """Do nothing; the fake receiver has nothing to release."""
        pass

    @synchronized
    def available(self):
        """Return how many messages this receiver has not fetched yet."""
        no_msgs = _fake_qpid_msg_manager.get_no_exch_msgs(self.source)
        index = _fake_qpid_msg_manager.get_exch_cons_index(self.source,
                                                           self.id)
        # Never report a negative backlog.
        return max(no_msgs - index, 0)


class FakeQpidSession(object):
    """Fake session that hands out fake senders and receivers.

    Senders and receivers are numbered with independent counters starting
    at 0 and kept in ``_senders`` / ``_receivers``.  The messages
    themselves live in the module-level fake message manager, keyed by the
    exchange key extracted from each address.
    """

    def __init__(self, connection=None, name=None, transactional=None):
        self.connection = connection
        self.name = name
        self.transactional = transactional
        self.conf = None
        self.url = None
        self._senders = {}
        self._receivers = {}
        self._sender_id = 0
        self._receiver_id = 0

    @synchronized
    def sender(self, target, **options):
        """Create a fake sender for the exchange named by ``target``."""
        exchange_key = self._extract_exchange_key(target)
        _fake_qpid_msg_manager.add_exchange(exchange_key)

        sender_id = self._sender_id
        sendobj = FakeQpidSessionSender(self, sender_id,
                                        exchange_key, options)
        self._senders[sender_id] = sendobj
        self._sender_id += 1
        return sendobj

    @synchronized
    def receiver(self, source, **options):
        """Create a fake receiver for the exchange named by ``source``.

        The receiver is also registered as a consumer of that exchange in
        the fake message manager.
        """
        exchange_key = self._extract_exchange_key(source)
        _fake_qpid_msg_manager.add_exchange(exchange_key)

        receiver_id = self._receiver_id
        recvobj = FakeQpidSessionReceiver(self, receiver_id,
                                          exchange_key, options)
        self._receivers[receiver_id] = recvobj
        _fake_qpid_msg_manager.add_exchange_consumer(exchange_key,
                                                     receiver_id)
        self._receiver_id += 1
        return recvobj

    def acknowledge(self, message=None, disposition=None, sync=True):
        """Do nothing; the fake session does not track acknowledgements."""
        pass

    @synchronized
    def flush_exchanges(self):
        """Discard every exchange held by the fake message manager."""
        _fake_qpid_msg_manager._exchanges = {}

    def _extract_exchange_key(self, exchange_msg):
        """Return the dictionary key identifying an exchange address.

        The key is the last '/'-separated element of the text in front of
        the first ';', with surrounding whitespace removed.

        Eg. if the exchange_msg (for qpid topology version 1)
        is 33/33 ; {"node": {"x-declare": {"auto-delete": true, ....
        then 33 is returned as the key.
        Eg 2. For topology v2, if the
        exchange_msg is - amq.direct/44 ; {"link": {"x-dec.......
        then 44 is returned
        """
        node = exchange_msg.split(';')[0]
        return node.split('/')[-1].strip()

    def close(self):
        """Do nothing; the fake session has nothing to release."""
        pass

_fake_session = FakeQpidSession()


def get_fake_qpid_session():
    """Return the module-level fake session.

    Every call returns the same ``_fake_session`` object, so the send and
    receive sessions created in _QpidBaseTestCase.setUp are one and the
    same fake.
    """
    return _fake_session


class QPidHATestCase(test_utils.BaseTestCase):
    """Test reconnect/ensure retry counts against a list of brokers.

    ``Connection._connect`` is replaced by a stub that counts attempts in
    ``self.info['attempt']`` and fails on demand (``self.info['fail']``).
    """

    @testtools.skipIf(qpid is None, "qpid not available")
    def setUp(self):
        """Create a connection with a stubbed-out _connect."""
        super(QPidHATestCase, self).setUp()
        self.brokers = ['host1', 'host2', 'host3', 'host4', 'host5']

        self.config(qpid_hosts=self.brokers,
                    qpid_username=None,
                    qpid_password=None)

        seen_hosts = set()
        self.info = {'attempt': 0, 'fail': False}

        def _connect(myself, broker):
            # do as little work that is enough to pass connection attempt
            myself.connection = mock.Mock()

            # No host may be tried twice until the set is cleared again.
            hostname = broker['host']
            self.assertNotIn(hostname, seen_hosts)
            seen_hosts.add(hostname)

            self.info['attempt'] += 1
            if self.info['fail']:
                raise qpid.messaging.exceptions.ConnectionError

        # just make sure connection instantiation does not fail with an
        # exception
        self.stubs.Set(qpid_driver.Connection, '_connect', _connect)

        # starting from the first broker in the list
        self.connection = qpid_driver.Connection(
            self.conf, messaging.TransportURL.parse(self.conf, None))
        self.addCleanup(self.connection.close)

        # From here on every connection attempt fails; the counter and the
        # set of seen hosts start afresh for the test itself.
        self.info['attempt'] = 0
        self.info['fail'] = True
        seen_hosts.clear()

    def _assert_ensure_gives_up(self, retry, expected_attempts):
        """Run ensure() with a failing callback and check the counters.

        The callback always raises ConnectionError; ensure() is expected
        to raise MessageDeliveryFailure after ``expected_attempts``
        connection attempts and exactly one call of the callback.
        """
        mock_callback = mock.Mock(
            side_effect=qpid.messaging.exceptions.ConnectionError)
        self.assertRaises(messaging.MessageDeliveryFailure,
                          self.connection.ensure, None, mock_callback,
                          retry=retry)
        self.assertEqual(expected_attempts, self.info['attempt'])
        self.assertEqual(1, mock_callback.call_count)

    def test_reconnect_order(self):
        """reconnect() makes one attempt per broker before giving up."""
        broker_count = len(self.brokers)
        self.assertRaises(messaging.MessageDeliveryFailure,
                          self.connection.reconnect,
                          retry=broker_count - 1)
        self.assertEqual(broker_count, self.info['attempt'])

    def test_ensure_four_retries(self):
        """retry=4 results in five connection attempts."""
        self._assert_ensure_gives_up(retry=4, expected_attempts=5)

    def test_ensure_one_retry(self):
        """retry=1 results in two connection attempts."""
        self._assert_ensure_gives_up(retry=1, expected_attempts=2)

    def test_ensure_no_retry(self):
        """retry=0 results in a single connection attempt."""
        self._assert_ensure_gives_up(retry=0, expected_attempts=1)