diff --git a/oslo/messaging/_drivers/protocols/amqp/controller.py b/oslo/messaging/_drivers/protocols/amqp/controller.py index 12d4b29bf..7bb5493b0 100644 --- a/oslo/messaging/_drivers/protocols/amqp/controller.py +++ b/oslo/messaging/_drivers/protocols/amqp/controller.py @@ -25,17 +25,18 @@ functions scheduled by the Controller. """ import abc -import collections import logging import threading import uuid +import proton import pyngus from six import moves from oslo.config import cfg from oslo.messaging._drivers.protocols.amqp import eventloop from oslo.messaging._drivers.protocols.amqp import opts +from oslo.messaging import transport LOG = logging.getLogger(__name__) @@ -188,25 +189,24 @@ class Server(pyngus.ReceiverEventHandler): class Hosts(object): - """An order list of peer addresses. Connection failover progresses from - one host to the next. + """An order list of TransportHost addresses. Connection failover + progresses from one host to the next. """ - HostnamePort = collections.namedtuple('HostnamePort', - ['hostname', 'port']) - def __init__(self, entries=None): - self._entries = [self.HostnamePort(h, p) for h, p in entries or []] + self._entries = entries[:] if entries else [] + for entry in self._entries: + entry.port = entry.port or 5672 self._current = 0 - def add(self, hostname, port=5672): - self._entries.append(self.HostnamePort(hostname, port)) + def add(self, transport_host): + self._entries.append(transport_host) @property def current(self): if len(self._entries): return self._entries[self._current] else: - return self.HostnamePort("localhost", 5672) + return transport.TransportHost(hostname="localhost", port=5672) def next(self): if len(self._entries) > 1: @@ -217,7 +217,7 @@ class Hosts(object): return '' def __str__(self): - return ", ".join(["%s:%i" % e for e in self._entries]) + return ", ".join(["%r" % th for th in self._entries]) class Controller(pyngus.ConnectionEventHandler): @@ -394,8 +394,7 @@ class Controller(pyngus.ConnectionEventHandler): def _do_connect(self): """Establish connection and reply subscription on processor thread.""" - hostname = self.hosts.current.hostname - port = self.hosts.current.port + host = self.hosts.current conn_props = {} if self.idle_timeout: conn_props["idle-time-out"] = float(self.idle_timeout) @@ -412,7 +411,7 @@ class Controller(pyngus.ConnectionEventHandler): self.ssl_key_file, self.ssl_key_password) conn_props["x-ssl-allow-cleartext"] = self.ssl_allow_insecure - self._socket_connection = self.processor.connect(hostname, port, + self._socket_connection = self.processor.connect(host, handler=self, properties=conn_props) LOG.debug("Connection initiated") @@ -535,6 +534,17 @@ class Controller(pyngus.ConnectionEventHandler): reason or "no reason given") self._socket_connection.connection.close() + def sasl_done(self, connection, pn_sasl, outcome): + """This is a Pyngus callback invoked by Pyngus when the SASL handshake + has completed. The outcome of the handshake will be OK on success or + AUTH on failure. + """ + if outcome == proton.SASL.AUTH: + LOG.error("Unable to connect to %s:%s, authentication failure.", + self.hosts.current.hostname, self.hosts.current.port) + # requires user intervention, treat it like a connection failure: + self._handle_connection_loss() + def _complete_shutdown(self): """The AMQP Connection has closed, and the driver shutdown is complete. Clean up controller resources and exit. @@ -574,6 +584,6 @@ class Controller(pyngus.ConnectionEventHandler): self._reconnecting = False self._senders = {} self._socket_connection.reset() - hostname, port = self.hosts.next() - LOG.info("Reconnecting to: %s:%i", hostname, port) - self._socket_connection.connect(hostname, port) + host = self.hosts.next() + LOG.info("Reconnecting to: %s:%i", host.hostname, host.port) + self._socket_connection.connect(host) diff --git a/oslo/messaging/_drivers/protocols/amqp/driver.py b/oslo/messaging/_drivers/protocols/amqp/driver.py index 455e32382..c84b90813 100644 --- a/oslo/messaging/_drivers/protocols/amqp/driver.py +++ b/oslo/messaging/_drivers/protocols/amqp/driver.py @@ -190,11 +190,9 @@ class ProtonDriver(base.BaseDriver): super(ProtonDriver, self).__init__(conf, url, default_exchange, allowed_remote_exmods) - # TODO(grs): handle authentication etc - hosts = [(h.hostname, h.port or 5672) for h in url.hosts] # Create a Controller that connects to the messaging service: - self._ctrl = controller.Controller(hosts, default_exchange, conf) + self._ctrl = controller.Controller(url.hosts, default_exchange, conf) # lazy connection setup - don't cause the controller to connect until # after the first messaging request: diff --git a/oslo/messaging/_drivers/protocols/amqp/eventloop.py b/oslo/messaging/_drivers/protocols/amqp/eventloop.py index 806f27100..f3d235a57 100644 --- a/oslo/messaging/_drivers/protocols/amqp/eventloop.py +++ b/oslo/messaging/_drivers/protocols/amqp/eventloop.py @@ -101,12 +101,12 @@ class _SocketConnection(): self._handler.socket_error(str(e)) return pyngus.Connection.EOS - def connect(self, hostname, port, sasl_mechanisms="ANONYMOUS"): + def connect(self, host): """Connect to host:port and start the AMQP protocol.""" - addr = socket.getaddrinfo(hostname, port, + addr = socket.getaddrinfo(host.hostname, host.port, socket.AF_INET, socket.SOCK_STREAM) if not addr: - key = "%s:%i" % (hostname, port) + key = "%s:%i" % (host.hostname, host.port) error = "Invalid peer address '%s'" % key LOG.error(error) self._handler.socket_error(error) @@ -124,9 +124,14 @@ class _SocketConnection(): return self.socket = my_socket - if sasl_mechanisms: - pn_sasl = self.connection.pn_sasl - pn_sasl.mechanisms(sasl_mechanisms) + # determine the proper SASL mechanism: PLAIN if a username/password is + # present, else ANONYMOUS + pn_sasl = self.connection.pn_sasl + if host.username: + password = host.password if host.password else "" + pn_sasl.plain(host.username, password) + else: + pn_sasl.mechanisms("ANONYMOUS") # TODO(kgiusti): server if accepting inbound connections pn_sasl.client() self.connection.open() @@ -259,10 +264,9 @@ class Thread(threading.Thread): LOG.info("eventloop shutdown requested") self._shutdown = True - def connect(self, hostname, port, handler, properties=None, name=None, - sasl_mechanisms="ANONYMOUS"): + def connect(self, host, handler, properties=None, name=None): """Get a _SocketConnection to a peer represented by url.""" - key = name or "%s:%i" % (hostname, port) + key = name or "%s:%i" % (host.hostname, host.port) # return pre-existing conn = self._container.get_connection(key) if conn: @@ -273,7 +277,7 @@ class Thread(threading.Thread): # no name was provided, the host:port combination sc = _SocketConnection(key, self._container, properties, handler=handler) - sc.connect(hostname, port, sasl_mechanisms) + sc.connect(host) return sc def run(self): diff --git a/tests/test_amqp_driver.py b/tests/test_amqp_driver.py index be613fbb7..64bbd3003 100644 --- a/tests/test_amqp_driver.py +++ b/tests/test_amqp_driver.py @@ -282,6 +282,62 @@ class TestAmqpNotification(_AmqpBrokerTestCase): driver.cleanup() +@testtools.skipUnless(pyngus, "proton modules not present") +class TestAuthentication(test_utils.BaseTestCase): + + def setUp(self): + super(TestAuthentication, self).setUp() + LOG.error("Starting Authentication Test") + # for simplicity, encode the credentials as they would appear 'on the + # wire' in a SASL frame - username and password prefixed by zero. + user_credentials = ["\0joe\0secret"] + self._broker = FakeBroker(sasl_mechanisms="PLAIN", + user_credentials=user_credentials) + self._broker.start() + + def tearDown(self): + super(TestAuthentication, self).tearDown() + self._broker.stop() + LOG.error("Authentication Test Ended") + + def test_authentication_ok(self): + """Verify that username and password given in TransportHost are + accepted by the broker. + """ + + addr = "amqp://joe:secret@%s:%d" % (self._broker.host, + self._broker.port) + url = messaging.TransportURL.parse(self.conf, addr) + driver = amqp_driver.ProtonDriver(self.conf, url) + target = messaging.Target(topic="test-topic") + listener = _ListenerThread(driver.listen(target), 1) + rc = driver.send(target, {"context": True}, + {"method": "echo"}, wait_for_reply=True) + self.assertIsNotNone(rc) + listener.join(timeout=30) + self.assertFalse(listener.isAlive()) + driver.cleanup() + + def test_authentication_failure(self): + """Verify that a bad password given in TransportHost is + rejected by the broker. + """ + + addr = "amqp://joe:badpass@%s:%d" % (self._broker.host, + self._broker.port) + url = messaging.TransportURL.parse(self.conf, addr) + driver = amqp_driver.ProtonDriver(self.conf, url) + target = messaging.Target(topic="test-topic") + _ListenerThread(driver.listen(target), 1) + self.assertRaises(messaging.MessagingTimeout, + driver.send, + target, {"context": True}, + {"method": "echo"}, + wait_for_reply=True, + timeout=2.0) + driver.cleanup() + + @testtools.skipUnless(pyngus, "proton modules not present") class TestFailover(test_utils.BaseTestCase): @@ -362,7 +418,8 @@ class FakeBroker(threading.Thread): class Connection(pyngus.ConnectionEventHandler): """A single AMQP connection.""" - def __init__(self, server, socket_, name): + def __init__(self, server, socket_, name, + sasl_mechanisms, user_credentials): """Create a Connection using socket_.""" self.socket = socket_ self.name = name @@ -370,8 +427,11 @@ class FakeBroker(threading.Thread): self.connection = server.container.create_connection(name, self) self.connection.user_context = self - self.connection.pn_sasl.mechanisms("ANONYMOUS") - self.connection.pn_sasl.server() + self.sasl_mechanisms = sasl_mechanisms + self.user_credentials = user_credentials + if sasl_mechanisms: + self.connection.pn_sasl.mechanisms(sasl_mechanisms) + self.connection.pn_sasl.server() self.connection.open() self.sender_links = set() self.closed = False @@ -436,7 +496,14 @@ class FakeBroker(threading.Thread): link_handle, addr) def sasl_step(self, connection, pn_sasl): - pn_sasl.done(pn_sasl.OK) # always permit + if self.sasl_mechanisms == 'PLAIN': + credentials = pn_sasl.recv() + if not credentials: + return # wait until some arrives + if credentials not in self.user_credentials: + # failed + return pn_sasl.done(pn_sasl.AUTH) + pn_sasl.done(pn_sasl.OK) class SenderLink(pyngus.SenderEventHandler): """An AMQP sending link.""" @@ -513,7 +580,9 @@ class FakeBroker(threading.Thread): broadcast_prefix="broadcast", group_prefix="unicast", address_separator=".", - sock_addr="", sock_port=0): + sock_addr="", sock_port=0, + sasl_mechanisms="ANONYMOUS", + user_credentials=None): """Create a fake broker listening on sock_addr:sock_port.""" if not pyngus: raise AssertionError("pyngus module not present") @@ -522,6 +591,8 @@ class FakeBroker(threading.Thread): self._broadcast_prefix = broadcast_prefix + address_separator self._group_prefix = group_prefix + address_separator self._address_separator = address_separator + self._sasl_mechanisms = sasl_mechanisms + self._user_credentials = user_credentials self._wakeup_pipe = os.pipe() self._my_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._my_socket.bind((sock_addr, sock_port)) @@ -581,7 +652,9 @@ class FakeBroker(threading.Thread): # create a new Connection for it: client_socket, client_address = self._my_socket.accept() name = str(client_address) - conn = FakeBroker.Connection(self, client_socket, name) + conn = FakeBroker.Connection(self, client_socket, name, + self._sasl_mechanisms, + self._user_credentials) self._connections[conn.name] = conn elif r is self._wakeup_pipe[0]: os.read(self._wakeup_pipe[0], 512)