Enable user authentication in the AMQP 1.0 driver
The TransportHost class allows user credentials to be supplied as part of the URL that identifies the host. Prior to this patch, these credentials - username and password - were ignored by the AMQP 1.0 driver. This prevents connections to a message broker that has been configured to use SASL PLAIN authentication. Closes-Bug: #1385445 Change-Id: Ib9279ed40b0f4cff62e1c742069c8f49f5625659
This commit is contained in:
parent
70910e0c9b
commit
f43fe66be0
@ -25,17 +25,18 @@ functions scheduled by the Controller.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
import proton
|
||||
import pyngus
|
||||
from six import moves
|
||||
|
||||
from oslo.config import cfg
|
||||
from oslo.messaging._drivers.protocols.amqp import eventloop
|
||||
from oslo.messaging._drivers.protocols.amqp import opts
|
||||
from oslo.messaging import transport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -188,25 +189,24 @@ class Server(pyngus.ReceiverEventHandler):
|
||||
|
||||
|
||||
class Hosts(object):
|
||||
"""An order list of peer addresses. Connection failover progresses from
|
||||
one host to the next.
|
||||
"""An order list of TransportHost addresses. Connection failover
|
||||
progresses from one host to the next.
|
||||
"""
|
||||
HostnamePort = collections.namedtuple('HostnamePort',
|
||||
['hostname', 'port'])
|
||||
|
||||
def __init__(self, entries=None):
|
||||
self._entries = [self.HostnamePort(h, p) for h, p in entries or []]
|
||||
self._entries = entries[:] if entries else []
|
||||
for entry in self._entries:
|
||||
entry.port = entry.port or 5672
|
||||
self._current = 0
|
||||
|
||||
def add(self, hostname, port=5672):
|
||||
self._entries.append(self.HostnamePort(hostname, port))
|
||||
def add(self, transport_host):
|
||||
self._entries.append(transport_host)
|
||||
|
||||
@property
|
||||
def current(self):
|
||||
if len(self._entries):
|
||||
return self._entries[self._current]
|
||||
else:
|
||||
return self.HostnamePort("localhost", 5672)
|
||||
return transport.TransportHost(hostname="localhost", port=5672)
|
||||
|
||||
def next(self):
|
||||
if len(self._entries) > 1:
|
||||
@ -217,7 +217,7 @@ class Hosts(object):
|
||||
return '<Hosts ' + str(self) + '>'
|
||||
|
||||
def __str__(self):
|
||||
return ", ".join(["%s:%i" % e for e in self._entries])
|
||||
return ", ".join(["%r" % th for th in self._entries])
|
||||
|
||||
|
||||
class Controller(pyngus.ConnectionEventHandler):
|
||||
@ -394,8 +394,7 @@ class Controller(pyngus.ConnectionEventHandler):
|
||||
|
||||
def _do_connect(self):
|
||||
"""Establish connection and reply subscription on processor thread."""
|
||||
hostname = self.hosts.current.hostname
|
||||
port = self.hosts.current.port
|
||||
host = self.hosts.current
|
||||
conn_props = {}
|
||||
if self.idle_timeout:
|
||||
conn_props["idle-time-out"] = float(self.idle_timeout)
|
||||
@ -412,7 +411,7 @@ class Controller(pyngus.ConnectionEventHandler):
|
||||
self.ssl_key_file,
|
||||
self.ssl_key_password)
|
||||
conn_props["x-ssl-allow-cleartext"] = self.ssl_allow_insecure
|
||||
self._socket_connection = self.processor.connect(hostname, port,
|
||||
self._socket_connection = self.processor.connect(host,
|
||||
handler=self,
|
||||
properties=conn_props)
|
||||
LOG.debug("Connection initiated")
|
||||
@ -535,6 +534,17 @@ class Controller(pyngus.ConnectionEventHandler):
|
||||
reason or "no reason given")
|
||||
self._socket_connection.connection.close()
|
||||
|
||||
def sasl_done(self, connection, pn_sasl, outcome):
|
||||
"""This is a Pyngus callback invoked by Pyngus when the SASL handshake
|
||||
has completed. The outcome of the handshake will be OK on success or
|
||||
AUTH on failure.
|
||||
"""
|
||||
if outcome == proton.SASL.AUTH:
|
||||
LOG.error("Unable to connect to %s:%s, authentication failure.",
|
||||
self.hosts.current.hostname, self.hosts.current.port)
|
||||
# requires user intervention, treat it like a connection failure:
|
||||
self._handle_connection_loss()
|
||||
|
||||
def _complete_shutdown(self):
|
||||
"""The AMQP Connection has closed, and the driver shutdown is complete.
|
||||
Clean up controller resources and exit.
|
||||
@ -574,6 +584,6 @@ class Controller(pyngus.ConnectionEventHandler):
|
||||
self._reconnecting = False
|
||||
self._senders = {}
|
||||
self._socket_connection.reset()
|
||||
hostname, port = self.hosts.next()
|
||||
LOG.info("Reconnecting to: %s:%i", hostname, port)
|
||||
self._socket_connection.connect(hostname, port)
|
||||
host = self.hosts.next()
|
||||
LOG.info("Reconnecting to: %s:%i", host.hostname, host.port)
|
||||
self._socket_connection.connect(host)
|
||||
|
@ -190,11 +190,9 @@ class ProtonDriver(base.BaseDriver):
|
||||
|
||||
super(ProtonDriver, self).__init__(conf, url, default_exchange,
|
||||
allowed_remote_exmods)
|
||||
# TODO(grs): handle authentication etc
|
||||
hosts = [(h.hostname, h.port or 5672) for h in url.hosts]
|
||||
|
||||
# Create a Controller that connects to the messaging service:
|
||||
self._ctrl = controller.Controller(hosts, default_exchange, conf)
|
||||
self._ctrl = controller.Controller(url.hosts, default_exchange, conf)
|
||||
|
||||
# lazy connection setup - don't cause the controller to connect until
|
||||
# after the first messaging request:
|
||||
|
@ -101,12 +101,12 @@ class _SocketConnection():
|
||||
self._handler.socket_error(str(e))
|
||||
return pyngus.Connection.EOS
|
||||
|
||||
def connect(self, hostname, port, sasl_mechanisms="ANONYMOUS"):
|
||||
def connect(self, host):
|
||||
"""Connect to host:port and start the AMQP protocol."""
|
||||
addr = socket.getaddrinfo(hostname, port,
|
||||
addr = socket.getaddrinfo(host.hostname, host.port,
|
||||
socket.AF_INET, socket.SOCK_STREAM)
|
||||
if not addr:
|
||||
key = "%s:%i" % (hostname, port)
|
||||
key = "%s:%i" % (host.hostname, host.port)
|
||||
error = "Invalid peer address '%s'" % key
|
||||
LOG.error(error)
|
||||
self._handler.socket_error(error)
|
||||
@ -124,9 +124,14 @@ class _SocketConnection():
|
||||
return
|
||||
self.socket = my_socket
|
||||
|
||||
if sasl_mechanisms:
|
||||
pn_sasl = self.connection.pn_sasl
|
||||
pn_sasl.mechanisms(sasl_mechanisms)
|
||||
# determine the proper SASL mechanism: PLAIN if a username/password is
|
||||
# present, else ANONYMOUS
|
||||
pn_sasl = self.connection.pn_sasl
|
||||
if host.username:
|
||||
password = host.password if host.password else ""
|
||||
pn_sasl.plain(host.username, password)
|
||||
else:
|
||||
pn_sasl.mechanisms("ANONYMOUS")
|
||||
# TODO(kgiusti): server if accepting inbound connections
|
||||
pn_sasl.client()
|
||||
self.connection.open()
|
||||
@ -259,10 +264,9 @@ class Thread(threading.Thread):
|
||||
LOG.info("eventloop shutdown requested")
|
||||
self._shutdown = True
|
||||
|
||||
def connect(self, hostname, port, handler, properties=None, name=None,
|
||||
sasl_mechanisms="ANONYMOUS"):
|
||||
def connect(self, host, handler, properties=None, name=None):
|
||||
"""Get a _SocketConnection to a peer represented by url."""
|
||||
key = name or "%s:%i" % (hostname, port)
|
||||
key = name or "%s:%i" % (host.hostname, host.port)
|
||||
# return pre-existing
|
||||
conn = self._container.get_connection(key)
|
||||
if conn:
|
||||
@ -273,7 +277,7 @@ class Thread(threading.Thread):
|
||||
# no name was provided, the host:port combination
|
||||
sc = _SocketConnection(key, self._container,
|
||||
properties, handler=handler)
|
||||
sc.connect(hostname, port, sasl_mechanisms)
|
||||
sc.connect(host)
|
||||
return sc
|
||||
|
||||
def run(self):
|
||||
|
@ -282,6 +282,62 @@ class TestAmqpNotification(_AmqpBrokerTestCase):
|
||||
driver.cleanup()
|
||||
|
||||
|
||||
@testtools.skipUnless(pyngus, "proton modules not present")
|
||||
class TestAuthentication(test_utils.BaseTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestAuthentication, self).setUp()
|
||||
LOG.error("Starting Authentication Test")
|
||||
# for simplicity, encode the credentials as they would appear 'on the
|
||||
# wire' in a SASL frame - username and password prefixed by zero.
|
||||
user_credentials = ["\0joe\0secret"]
|
||||
self._broker = FakeBroker(sasl_mechanisms="PLAIN",
|
||||
user_credentials=user_credentials)
|
||||
self._broker.start()
|
||||
|
||||
def tearDown(self):
|
||||
super(TestAuthentication, self).tearDown()
|
||||
self._broker.stop()
|
||||
LOG.error("Authentication Test Ended")
|
||||
|
||||
def test_authentication_ok(self):
|
||||
"""Verify that username and password given in TransportHost are
|
||||
accepted by the broker.
|
||||
"""
|
||||
|
||||
addr = "amqp://joe:secret@%s:%d" % (self._broker.host,
|
||||
self._broker.port)
|
||||
url = messaging.TransportURL.parse(self.conf, addr)
|
||||
driver = amqp_driver.ProtonDriver(self.conf, url)
|
||||
target = messaging.Target(topic="test-topic")
|
||||
listener = _ListenerThread(driver.listen(target), 1)
|
||||
rc = driver.send(target, {"context": True},
|
||||
{"method": "echo"}, wait_for_reply=True)
|
||||
self.assertIsNotNone(rc)
|
||||
listener.join(timeout=30)
|
||||
self.assertFalse(listener.isAlive())
|
||||
driver.cleanup()
|
||||
|
||||
def test_authentication_failure(self):
|
||||
"""Verify that a bad password given in TransportHost is
|
||||
rejected by the broker.
|
||||
"""
|
||||
|
||||
addr = "amqp://joe:badpass@%s:%d" % (self._broker.host,
|
||||
self._broker.port)
|
||||
url = messaging.TransportURL.parse(self.conf, addr)
|
||||
driver = amqp_driver.ProtonDriver(self.conf, url)
|
||||
target = messaging.Target(topic="test-topic")
|
||||
_ListenerThread(driver.listen(target), 1)
|
||||
self.assertRaises(messaging.MessagingTimeout,
|
||||
driver.send,
|
||||
target, {"context": True},
|
||||
{"method": "echo"},
|
||||
wait_for_reply=True,
|
||||
timeout=2.0)
|
||||
driver.cleanup()
|
||||
|
||||
|
||||
@testtools.skipUnless(pyngus, "proton modules not present")
|
||||
class TestFailover(test_utils.BaseTestCase):
|
||||
|
||||
@ -362,7 +418,8 @@ class FakeBroker(threading.Thread):
|
||||
class Connection(pyngus.ConnectionEventHandler):
|
||||
"""A single AMQP connection."""
|
||||
|
||||
def __init__(self, server, socket_, name):
|
||||
def __init__(self, server, socket_, name,
|
||||
sasl_mechanisms, user_credentials):
|
||||
"""Create a Connection using socket_."""
|
||||
self.socket = socket_
|
||||
self.name = name
|
||||
@ -370,8 +427,11 @@ class FakeBroker(threading.Thread):
|
||||
self.connection = server.container.create_connection(name,
|
||||
self)
|
||||
self.connection.user_context = self
|
||||
self.connection.pn_sasl.mechanisms("ANONYMOUS")
|
||||
self.connection.pn_sasl.server()
|
||||
self.sasl_mechanisms = sasl_mechanisms
|
||||
self.user_credentials = user_credentials
|
||||
if sasl_mechanisms:
|
||||
self.connection.pn_sasl.mechanisms(sasl_mechanisms)
|
||||
self.connection.pn_sasl.server()
|
||||
self.connection.open()
|
||||
self.sender_links = set()
|
||||
self.closed = False
|
||||
@ -436,7 +496,14 @@ class FakeBroker(threading.Thread):
|
||||
link_handle, addr)
|
||||
|
||||
def sasl_step(self, connection, pn_sasl):
|
||||
pn_sasl.done(pn_sasl.OK) # always permit
|
||||
if self.sasl_mechanisms == 'PLAIN':
|
||||
credentials = pn_sasl.recv()
|
||||
if not credentials:
|
||||
return # wait until some arrives
|
||||
if credentials not in self.user_credentials:
|
||||
# failed
|
||||
return pn_sasl.done(pn_sasl.AUTH)
|
||||
pn_sasl.done(pn_sasl.OK)
|
||||
|
||||
class SenderLink(pyngus.SenderEventHandler):
|
||||
"""An AMQP sending link."""
|
||||
@ -513,7 +580,9 @@ class FakeBroker(threading.Thread):
|
||||
broadcast_prefix="broadcast",
|
||||
group_prefix="unicast",
|
||||
address_separator=".",
|
||||
sock_addr="", sock_port=0):
|
||||
sock_addr="", sock_port=0,
|
||||
sasl_mechanisms="ANONYMOUS",
|
||||
user_credentials=None):
|
||||
"""Create a fake broker listening on sock_addr:sock_port."""
|
||||
if not pyngus:
|
||||
raise AssertionError("pyngus module not present")
|
||||
@ -522,6 +591,8 @@ class FakeBroker(threading.Thread):
|
||||
self._broadcast_prefix = broadcast_prefix + address_separator
|
||||
self._group_prefix = group_prefix + address_separator
|
||||
self._address_separator = address_separator
|
||||
self._sasl_mechanisms = sasl_mechanisms
|
||||
self._user_credentials = user_credentials
|
||||
self._wakeup_pipe = os.pipe()
|
||||
self._my_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self._my_socket.bind((sock_addr, sock_port))
|
||||
@ -581,7 +652,9 @@ class FakeBroker(threading.Thread):
|
||||
# create a new Connection for it:
|
||||
client_socket, client_address = self._my_socket.accept()
|
||||
name = str(client_address)
|
||||
conn = FakeBroker.Connection(self, client_socket, name)
|
||||
conn = FakeBroker.Connection(self, client_socket, name,
|
||||
self._sasl_mechanisms,
|
||||
self._user_credentials)
|
||||
self._connections[conn.name] = conn
|
||||
elif r is self._wakeup_pipe[0]:
|
||||
os.read(self._wakeup_pipe[0], 512)
|
||||
|
Loading…
Reference in New Issue
Block a user