From f43fe66be05cbad316a835f3308188515a93b6eb Mon Sep 17 00:00:00 2001 From: Kenneth Giusti Date: Wed, 29 Oct 2014 17:39:40 -0400 Subject: [PATCH] Enable user authentication in the AMQP 1.0 driver The TransportHost class allows user credentials to be supplied as part of the URL that identifies the host. Prior to this patch, these credentials - username and password - were ignored by the AMQP 1.0 driver. This prevents connections to a message broker that has been configured to use SASL PLAIN authentication. Closes-Bug: #1385445 Change-Id: Ib9279ed40b0f4cff62e1c742069c8f49f5625659 --- .../_drivers/protocols/amqp/controller.py | 44 ++++++---- .../_drivers/protocols/amqp/driver.py | 4 +- .../_drivers/protocols/amqp/eventloop.py | 24 +++--- tests/test_amqp_driver.py | 85 +++++++++++++++++-- 4 files changed, 121 insertions(+), 36 deletions(-) diff --git a/oslo/messaging/_drivers/protocols/amqp/controller.py b/oslo/messaging/_drivers/protocols/amqp/controller.py index 12d4b29bf..7bb5493b0 100644 --- a/oslo/messaging/_drivers/protocols/amqp/controller.py +++ b/oslo/messaging/_drivers/protocols/amqp/controller.py @@ -25,17 +25,18 @@ functions scheduled by the Controller. """ import abc -import collections import logging import threading import uuid +import proton import pyngus from six import moves from oslo.config import cfg from oslo.messaging._drivers.protocols.amqp import eventloop from oslo.messaging._drivers.protocols.amqp import opts +from oslo.messaging import transport LOG = logging.getLogger(__name__) @@ -188,25 +189,24 @@ class Server(pyngus.ReceiverEventHandler): class Hosts(object): - """An order list of peer addresses. Connection failover progresses from - one host to the next. + """An order list of TransportHost addresses. Connection failover + progresses from one host to the next. """ - HostnamePort = collections.namedtuple('HostnamePort', - ['hostname', 'port']) - def __init__(self, entries=None): - self._entries = [self.HostnamePort(h, p) for h, p in entries or []] + self._entries = entries[:] if entries else [] + for entry in self._entries: + entry.port = entry.port or 5672 self._current = 0 - def add(self, hostname, port=5672): - self._entries.append(self.HostnamePort(hostname, port)) + def add(self, transport_host): + self._entries.append(transport_host) @property def current(self): if len(self._entries): return self._entries[self._current] else: - return self.HostnamePort("localhost", 5672) + return transport.TransportHost(hostname="localhost", port=5672) def next(self): if len(self._entries) > 1: @@ -217,7 +217,7 @@ class Hosts(object): return '' def __str__(self): - return ", ".join(["%s:%i" % e for e in self._entries]) + return ", ".join(["%r" % th for th in self._entries]) class Controller(pyngus.ConnectionEventHandler): @@ -394,8 +394,7 @@ class Controller(pyngus.ConnectionEventHandler): def _do_connect(self): """Establish connection and reply subscription on processor thread.""" - hostname = self.hosts.current.hostname - port = self.hosts.current.port + host = self.hosts.current conn_props = {} if self.idle_timeout: conn_props["idle-time-out"] = float(self.idle_timeout) @@ -412,7 +411,7 @@ class Controller(pyngus.ConnectionEventHandler): self.ssl_key_file, self.ssl_key_password) conn_props["x-ssl-allow-cleartext"] = self.ssl_allow_insecure - self._socket_connection = self.processor.connect(hostname, port, + self._socket_connection = self.processor.connect(host, handler=self, properties=conn_props) LOG.debug("Connection initiated") @@ -535,6 +534,17 @@ class Controller(pyngus.ConnectionEventHandler): reason or "no reason given") self._socket_connection.connection.close() + def sasl_done(self, connection, pn_sasl, outcome): + """This is a Pyngus callback invoked by Pyngus when the SASL handshake + has completed. The outcome of the handshake will be OK on success or + AUTH on failure. + """ + if outcome == proton.SASL.AUTH: + LOG.error("Unable to connect to %s:%s, authentication failure.", + self.hosts.current.hostname, self.hosts.current.port) + # requires user intervention, treat it like a connection failure: + self._handle_connection_loss() + def _complete_shutdown(self): """The AMQP Connection has closed, and the driver shutdown is complete. Clean up controller resources and exit. @@ -574,6 +584,6 @@ class Controller(pyngus.ConnectionEventHandler): self._reconnecting = False self._senders = {} self._socket_connection.reset() - hostname, port = self.hosts.next() - LOG.info("Reconnecting to: %s:%i", hostname, port) - self._socket_connection.connect(hostname, port) + host = self.hosts.next() + LOG.info("Reconnecting to: %s:%i", host.hostname, host.port) + self._socket_connection.connect(host) diff --git a/oslo/messaging/_drivers/protocols/amqp/driver.py b/oslo/messaging/_drivers/protocols/amqp/driver.py index 455e32382..c84b90813 100644 --- a/oslo/messaging/_drivers/protocols/amqp/driver.py +++ b/oslo/messaging/_drivers/protocols/amqp/driver.py @@ -190,11 +190,9 @@ class ProtonDriver(base.BaseDriver): super(ProtonDriver, self).__init__(conf, url, default_exchange, allowed_remote_exmods) - # TODO(grs): handle authentication etc - hosts = [(h.hostname, h.port or 5672) for h in url.hosts] # Create a Controller that connects to the messaging service: - self._ctrl = controller.Controller(hosts, default_exchange, conf) + self._ctrl = controller.Controller(url.hosts, default_exchange, conf) # lazy connection setup - don't cause the controller to connect until # after the first messaging request: diff --git a/oslo/messaging/_drivers/protocols/amqp/eventloop.py b/oslo/messaging/_drivers/protocols/amqp/eventloop.py index 806f27100..f3d235a57 100644 --- a/oslo/messaging/_drivers/protocols/amqp/eventloop.py +++ b/oslo/messaging/_drivers/protocols/amqp/eventloop.py @@ -101,12 +101,12 @@ class _SocketConnection(): self._handler.socket_error(str(e)) return pyngus.Connection.EOS - def connect(self, hostname, port, sasl_mechanisms="ANONYMOUS"): + def connect(self, host): """Connect to host:port and start the AMQP protocol.""" - addr = socket.getaddrinfo(hostname, port, + addr = socket.getaddrinfo(host.hostname, host.port, socket.AF_INET, socket.SOCK_STREAM) if not addr: - key = "%s:%i" % (hostname, port) + key = "%s:%i" % (host.hostname, host.port) error = "Invalid peer address '%s'" % key LOG.error(error) self._handler.socket_error(error) @@ -124,9 +124,14 @@ class _SocketConnection(): return self.socket = my_socket - if sasl_mechanisms: - pn_sasl = self.connection.pn_sasl - pn_sasl.mechanisms(sasl_mechanisms) + # determine the proper SASL mechanism: PLAIN if a username/password is + # present, else ANONYMOUS + pn_sasl = self.connection.pn_sasl + if host.username: + password = host.password if host.password else "" + pn_sasl.plain(host.username, password) + else: + pn_sasl.mechanisms("ANONYMOUS") # TODO(kgiusti): server if accepting inbound connections pn_sasl.client() self.connection.open() @@ -259,10 +264,9 @@ class Thread(threading.Thread): LOG.info("eventloop shutdown requested") self._shutdown = True - def connect(self, hostname, port, handler, properties=None, name=None, - sasl_mechanisms="ANONYMOUS"): + def connect(self, host, handler, properties=None, name=None): """Get a _SocketConnection to a peer represented by url.""" - key = name or "%s:%i" % (hostname, port) + key = name or "%s:%i" % (host.hostname, host.port) # return pre-existing conn = self._container.get_connection(key) if conn: @@ -273,7 +277,7 @@ class Thread(threading.Thread): # no name was provided, the host:port combination sc = _SocketConnection(key, self._container, properties, handler=handler) - sc.connect(hostname, port, sasl_mechanisms) + sc.connect(host) return sc def run(self): diff --git a/tests/test_amqp_driver.py b/tests/test_amqp_driver.py index be613fbb7..64bbd3003 100644 --- a/tests/test_amqp_driver.py +++ b/tests/test_amqp_driver.py @@ -282,6 +282,62 @@ class TestAmqpNotification(_AmqpBrokerTestCase): driver.cleanup() +@testtools.skipUnless(pyngus, "proton modules not present") +class TestAuthentication(test_utils.BaseTestCase): + + def setUp(self): + super(TestAuthentication, self).setUp() + LOG.error("Starting Authentication Test") + # for simplicity, encode the credentials as they would appear 'on the + # wire' in a SASL frame - username and password prefixed by zero. + user_credentials = ["\0joe\0secret"] + self._broker = FakeBroker(sasl_mechanisms="PLAIN", + user_credentials=user_credentials) + self._broker.start() + + def tearDown(self): + super(TestAuthentication, self).tearDown() + self._broker.stop() + LOG.error("Authentication Test Ended") + + def test_authentication_ok(self): + """Verify that username and password given in TransportHost are + accepted by the broker. + """ + + addr = "amqp://joe:secret@%s:%d" % (self._broker.host, + self._broker.port) + url = messaging.TransportURL.parse(self.conf, addr) + driver = amqp_driver.ProtonDriver(self.conf, url) + target = messaging.Target(topic="test-topic") + listener = _ListenerThread(driver.listen(target), 1) + rc = driver.send(target, {"context": True}, + {"method": "echo"}, wait_for_reply=True) + self.assertIsNotNone(rc) + listener.join(timeout=30) + self.assertFalse(listener.isAlive()) + driver.cleanup() + + def test_authentication_failure(self): + """Verify that a bad password given in TransportHost is + rejected by the broker. + """ + + addr = "amqp://joe:badpass@%s:%d" % (self._broker.host, + self._broker.port) + url = messaging.TransportURL.parse(self.conf, addr) + driver = amqp_driver.ProtonDriver(self.conf, url) + target = messaging.Target(topic="test-topic") + _ListenerThread(driver.listen(target), 1) + self.assertRaises(messaging.MessagingTimeout, + driver.send, + target, {"context": True}, + {"method": "echo"}, + wait_for_reply=True, + timeout=2.0) + driver.cleanup() + + @testtools.skipUnless(pyngus, "proton modules not present") class TestFailover(test_utils.BaseTestCase): @@ -362,7 +418,8 @@ class FakeBroker(threading.Thread): class Connection(pyngus.ConnectionEventHandler): """A single AMQP connection.""" - def __init__(self, server, socket_, name): + def __init__(self, server, socket_, name, + sasl_mechanisms, user_credentials): """Create a Connection using socket_.""" self.socket = socket_ self.name = name @@ -370,8 +427,11 @@ class FakeBroker(threading.Thread): self.connection = server.container.create_connection(name, self) self.connection.user_context = self - self.connection.pn_sasl.mechanisms("ANONYMOUS") - self.connection.pn_sasl.server() + self.sasl_mechanisms = sasl_mechanisms + self.user_credentials = user_credentials + if sasl_mechanisms: + self.connection.pn_sasl.mechanisms(sasl_mechanisms) + self.connection.pn_sasl.server() self.connection.open() self.sender_links = set() self.closed = False @@ -436,7 +496,14 @@ class FakeBroker(threading.Thread): link_handle, addr) def sasl_step(self, connection, pn_sasl): - pn_sasl.done(pn_sasl.OK) # always permit + if self.sasl_mechanisms == 'PLAIN': + credentials = pn_sasl.recv() + if not credentials: + return # wait until some arrives + if credentials not in self.user_credentials: + # failed + return pn_sasl.done(pn_sasl.AUTH) + pn_sasl.done(pn_sasl.OK) class SenderLink(pyngus.SenderEventHandler): """An AMQP sending link.""" @@ -513,7 +580,9 @@ class FakeBroker(threading.Thread): broadcast_prefix="broadcast", group_prefix="unicast", address_separator=".", - sock_addr="", sock_port=0): + sock_addr="", sock_port=0, + sasl_mechanisms="ANONYMOUS", + user_credentials=None): """Create a fake broker listening on sock_addr:sock_port.""" if not pyngus: raise AssertionError("pyngus module not present") @@ -522,6 +591,8 @@ class FakeBroker(threading.Thread): self._broadcast_prefix = broadcast_prefix + address_separator self._group_prefix = group_prefix + address_separator self._address_separator = address_separator + self._sasl_mechanisms = sasl_mechanisms + self._user_credentials = user_credentials self._wakeup_pipe = os.pipe() self._my_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._my_socket.bind((sock_addr, sock_port)) @@ -581,7 +652,9 @@ class FakeBroker(threading.Thread): # create a new Connection for it: client_socket, client_address = self._my_socket.accept() name = str(client_address) - conn = FakeBroker.Connection(self, client_socket, name) + conn = FakeBroker.Connection(self, client_socket, name, + self._sasl_mechanisms, + self._user_credentials) self._connections[conn.name] = conn elif r is self._wakeup_pipe[0]: os.read(self._wakeup_pipe[0], 512)