Merge "Enable user authentication in the AMQP 1.0 driver"

This commit is contained in:
Jenkins 2014-11-19 16:34:00 +00:00 committed by Gerrit Code Review
commit 1a53d3ed30
4 changed files with 121 additions and 36 deletions

View File

@ -25,17 +25,18 @@ functions scheduled by the Controller.
"""
import abc
import collections
import logging
import threading
import uuid
import proton
import pyngus
from six import moves
from oslo.config import cfg
from oslo.messaging._drivers.protocols.amqp import eventloop
from oslo.messaging._drivers.protocols.amqp import opts
from oslo.messaging import transport
LOG = logging.getLogger(__name__)
@ -188,25 +189,24 @@ class Server(pyngus.ReceiverEventHandler):
class Hosts(object):
"""An order list of peer addresses. Connection failover progresses from
one host to the next.
"""An order list of TransportHost addresses. Connection failover
progresses from one host to the next.
"""
HostnamePort = collections.namedtuple('HostnamePort',
['hostname', 'port'])
def __init__(self, entries=None):
self._entries = [self.HostnamePort(h, p) for h, p in entries or []]
self._entries = entries[:] if entries else []
for entry in self._entries:
entry.port = entry.port or 5672
self._current = 0
def add(self, hostname, port=5672):
self._entries.append(self.HostnamePort(hostname, port))
def add(self, transport_host):
self._entries.append(transport_host)
@property
def current(self):
if len(self._entries):
return self._entries[self._current]
else:
return self.HostnamePort("localhost", 5672)
return transport.TransportHost(hostname="localhost", port=5672)
def next(self):
if len(self._entries) > 1:
@ -217,7 +217,7 @@ class Hosts(object):
return '<Hosts ' + str(self) + '>'
def __str__(self):
return ", ".join(["%s:%i" % e for e in self._entries])
return ", ".join(["%r" % th for th in self._entries])
class Controller(pyngus.ConnectionEventHandler):
@ -394,8 +394,7 @@ class Controller(pyngus.ConnectionEventHandler):
def _do_connect(self):
"""Establish connection and reply subscription on processor thread."""
hostname = self.hosts.current.hostname
port = self.hosts.current.port
host = self.hosts.current
conn_props = {}
if self.idle_timeout:
conn_props["idle-time-out"] = float(self.idle_timeout)
@ -412,7 +411,7 @@ class Controller(pyngus.ConnectionEventHandler):
self.ssl_key_file,
self.ssl_key_password)
conn_props["x-ssl-allow-cleartext"] = self.ssl_allow_insecure
self._socket_connection = self.processor.connect(hostname, port,
self._socket_connection = self.processor.connect(host,
handler=self,
properties=conn_props)
LOG.debug("Connection initiated")
@ -535,6 +534,17 @@ class Controller(pyngus.ConnectionEventHandler):
reason or "no reason given")
self._socket_connection.connection.close()
def sasl_done(self, connection, pn_sasl, outcome):
"""This is a Pyngus callback invoked by Pyngus when the SASL handshake
has completed. The outcome of the handshake will be OK on success or
AUTH on failure.
"""
if outcome == proton.SASL.AUTH:
LOG.error("Unable to connect to %s:%s, authentication failure.",
self.hosts.current.hostname, self.hosts.current.port)
# requires user intervention, treat it like a connection failure:
self._handle_connection_loss()
def _complete_shutdown(self):
"""The AMQP Connection has closed, and the driver shutdown is complete.
Clean up controller resources and exit.
@ -574,6 +584,6 @@ class Controller(pyngus.ConnectionEventHandler):
self._reconnecting = False
self._senders = {}
self._socket_connection.reset()
hostname, port = self.hosts.next()
LOG.info("Reconnecting to: %s:%i", hostname, port)
self._socket_connection.connect(hostname, port)
host = self.hosts.next()
LOG.info("Reconnecting to: %s:%i", host.hostname, host.port)
self._socket_connection.connect(host)

View File

@ -190,11 +190,9 @@ class ProtonDriver(base.BaseDriver):
super(ProtonDriver, self).__init__(conf, url, default_exchange,
allowed_remote_exmods)
# TODO(grs): handle authentication etc
hosts = [(h.hostname, h.port or 5672) for h in url.hosts]
# Create a Controller that connects to the messaging service:
self._ctrl = controller.Controller(hosts, default_exchange, conf)
self._ctrl = controller.Controller(url.hosts, default_exchange, conf)
# lazy connection setup - don't cause the controller to connect until
# after the first messaging request:

View File

@ -101,12 +101,12 @@ class _SocketConnection():
self._handler.socket_error(str(e))
return pyngus.Connection.EOS
def connect(self, hostname, port, sasl_mechanisms="ANONYMOUS"):
def connect(self, host):
"""Connect to host:port and start the AMQP protocol."""
addr = socket.getaddrinfo(hostname, port,
addr = socket.getaddrinfo(host.hostname, host.port,
socket.AF_INET, socket.SOCK_STREAM)
if not addr:
key = "%s:%i" % (hostname, port)
key = "%s:%i" % (host.hostname, host.port)
error = "Invalid peer address '%s'" % key
LOG.error(error)
self._handler.socket_error(error)
@ -124,9 +124,14 @@ class _SocketConnection():
return
self.socket = my_socket
if sasl_mechanisms:
pn_sasl = self.connection.pn_sasl
pn_sasl.mechanisms(sasl_mechanisms)
# determine the proper SASL mechanism: PLAIN if a username/password is
# present, else ANONYMOUS
pn_sasl = self.connection.pn_sasl
if host.username:
password = host.password if host.password else ""
pn_sasl.plain(host.username, password)
else:
pn_sasl.mechanisms("ANONYMOUS")
# TODO(kgiusti): server if accepting inbound connections
pn_sasl.client()
self.connection.open()
@ -259,10 +264,9 @@ class Thread(threading.Thread):
LOG.info("eventloop shutdown requested")
self._shutdown = True
def connect(self, hostname, port, handler, properties=None, name=None,
sasl_mechanisms="ANONYMOUS"):
def connect(self, host, handler, properties=None, name=None):
"""Get a _SocketConnection to a peer represented by url."""
key = name or "%s:%i" % (hostname, port)
key = name or "%s:%i" % (host.hostname, host.port)
# return pre-existing
conn = self._container.get_connection(key)
if conn:
@ -273,7 +277,7 @@ class Thread(threading.Thread):
# no name was provided, the host:port combination
sc = _SocketConnection(key, self._container,
properties, handler=handler)
sc.connect(hostname, port, sasl_mechanisms)
sc.connect(host)
return sc
def run(self):

View File

@ -282,6 +282,62 @@ class TestAmqpNotification(_AmqpBrokerTestCase):
driver.cleanup()
@testtools.skipUnless(pyngus, "proton modules not present")
class TestAuthentication(test_utils.BaseTestCase):
def setUp(self):
super(TestAuthentication, self).setUp()
LOG.error("Starting Authentication Test")
# for simplicity, encode the credentials as they would appear 'on the
# wire' in a SASL frame - username and password prefixed by zero.
user_credentials = ["\0joe\0secret"]
self._broker = FakeBroker(sasl_mechanisms="PLAIN",
user_credentials=user_credentials)
self._broker.start()
def tearDown(self):
super(TestAuthentication, self).tearDown()
self._broker.stop()
LOG.error("Authentication Test Ended")
def test_authentication_ok(self):
"""Verify that username and password given in TransportHost are
accepted by the broker.
"""
addr = "amqp://joe:secret@%s:%d" % (self._broker.host,
self._broker.port)
url = messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url)
target = messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1)
rc = driver.send(target, {"context": True},
{"method": "echo"}, wait_for_reply=True)
self.assertIsNotNone(rc)
listener.join(timeout=30)
self.assertFalse(listener.isAlive())
driver.cleanup()
def test_authentication_failure(self):
"""Verify that a bad password given in TransportHost is
rejected by the broker.
"""
addr = "amqp://joe:badpass@%s:%d" % (self._broker.host,
self._broker.port)
url = messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url)
target = messaging.Target(topic="test-topic")
_ListenerThread(driver.listen(target), 1)
self.assertRaises(messaging.MessagingTimeout,
driver.send,
target, {"context": True},
{"method": "echo"},
wait_for_reply=True,
timeout=2.0)
driver.cleanup()
@testtools.skipUnless(pyngus, "proton modules not present")
class TestFailover(test_utils.BaseTestCase):
@ -362,7 +418,8 @@ class FakeBroker(threading.Thread):
class Connection(pyngus.ConnectionEventHandler):
"""A single AMQP connection."""
def __init__(self, server, socket_, name):
def __init__(self, server, socket_, name,
sasl_mechanisms, user_credentials):
"""Create a Connection using socket_."""
self.socket = socket_
self.name = name
@ -370,8 +427,11 @@ class FakeBroker(threading.Thread):
self.connection = server.container.create_connection(name,
self)
self.connection.user_context = self
self.connection.pn_sasl.mechanisms("ANONYMOUS")
self.connection.pn_sasl.server()
self.sasl_mechanisms = sasl_mechanisms
self.user_credentials = user_credentials
if sasl_mechanisms:
self.connection.pn_sasl.mechanisms(sasl_mechanisms)
self.connection.pn_sasl.server()
self.connection.open()
self.sender_links = set()
self.closed = False
@ -436,7 +496,14 @@ class FakeBroker(threading.Thread):
link_handle, addr)
def sasl_step(self, connection, pn_sasl):
pn_sasl.done(pn_sasl.OK) # always permit
if self.sasl_mechanisms == 'PLAIN':
credentials = pn_sasl.recv()
if not credentials:
return # wait until some arrives
if credentials not in self.user_credentials:
# failed
return pn_sasl.done(pn_sasl.AUTH)
pn_sasl.done(pn_sasl.OK)
class SenderLink(pyngus.SenderEventHandler):
"""An AMQP sending link."""
@ -513,7 +580,9 @@ class FakeBroker(threading.Thread):
broadcast_prefix="broadcast",
group_prefix="unicast",
address_separator=".",
sock_addr="", sock_port=0):
sock_addr="", sock_port=0,
sasl_mechanisms="ANONYMOUS",
user_credentials=None):
"""Create a fake broker listening on sock_addr:sock_port."""
if not pyngus:
raise AssertionError("pyngus module not present")
@ -522,6 +591,8 @@ class FakeBroker(threading.Thread):
self._broadcast_prefix = broadcast_prefix + address_separator
self._group_prefix = group_prefix + address_separator
self._address_separator = address_separator
self._sasl_mechanisms = sasl_mechanisms
self._user_credentials = user_credentials
self._wakeup_pipe = os.pipe()
self._my_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._my_socket.bind((sock_addr, sock_port))
@ -581,7 +652,9 @@ class FakeBroker(threading.Thread):
# create a new Connection for it:
client_socket, client_address = self._my_socket.accept()
name = str(client_address)
conn = FakeBroker.Connection(self, client_socket, name)
conn = FakeBroker.Connection(self, client_socket, name,
self._sasl_mechanisms,
self._user_credentials)
self._connections[conn.name] = conn
elif r is self._wakeup_pipe[0]:
os.read(self._wakeup_pipe[0], 512)