Merge "Enable user authentication in the AMQP 1.0 driver"

This commit is contained in:
Jenkins 2014-11-19 16:34:00 +00:00 committed by Gerrit Code Review
commit 1a53d3ed30
4 changed files with 121 additions and 36 deletions

View File

@ -25,17 +25,18 @@ functions scheduled by the Controller.
""" """
import abc import abc
import collections
import logging import logging
import threading import threading
import uuid import uuid
import proton
import pyngus import pyngus
from six import moves from six import moves
from oslo.config import cfg from oslo.config import cfg
from oslo.messaging._drivers.protocols.amqp import eventloop from oslo.messaging._drivers.protocols.amqp import eventloop
from oslo.messaging._drivers.protocols.amqp import opts from oslo.messaging._drivers.protocols.amqp import opts
from oslo.messaging import transport
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -188,25 +189,24 @@ class Server(pyngus.ReceiverEventHandler):
class Hosts(object): class Hosts(object):
"""An order list of peer addresses. Connection failover progresses from """An order list of TransportHost addresses. Connection failover
one host to the next. progresses from one host to the next.
""" """
HostnamePort = collections.namedtuple('HostnamePort',
['hostname', 'port'])
def __init__(self, entries=None): def __init__(self, entries=None):
self._entries = [self.HostnamePort(h, p) for h, p in entries or []] self._entries = entries[:] if entries else []
for entry in self._entries:
entry.port = entry.port or 5672
self._current = 0 self._current = 0
def add(self, hostname, port=5672): def add(self, transport_host):
self._entries.append(self.HostnamePort(hostname, port)) self._entries.append(transport_host)
@property @property
def current(self): def current(self):
if len(self._entries): if len(self._entries):
return self._entries[self._current] return self._entries[self._current]
else: else:
return self.HostnamePort("localhost", 5672) return transport.TransportHost(hostname="localhost", port=5672)
def next(self): def next(self):
if len(self._entries) > 1: if len(self._entries) > 1:
@ -217,7 +217,7 @@ class Hosts(object):
return '<Hosts ' + str(self) + '>' return '<Hosts ' + str(self) + '>'
def __str__(self): def __str__(self):
return ", ".join(["%s:%i" % e for e in self._entries]) return ", ".join(["%r" % th for th in self._entries])
class Controller(pyngus.ConnectionEventHandler): class Controller(pyngus.ConnectionEventHandler):
@ -394,8 +394,7 @@ class Controller(pyngus.ConnectionEventHandler):
def _do_connect(self): def _do_connect(self):
"""Establish connection and reply subscription on processor thread.""" """Establish connection and reply subscription on processor thread."""
hostname = self.hosts.current.hostname host = self.hosts.current
port = self.hosts.current.port
conn_props = {} conn_props = {}
if self.idle_timeout: if self.idle_timeout:
conn_props["idle-time-out"] = float(self.idle_timeout) conn_props["idle-time-out"] = float(self.idle_timeout)
@ -412,7 +411,7 @@ class Controller(pyngus.ConnectionEventHandler):
self.ssl_key_file, self.ssl_key_file,
self.ssl_key_password) self.ssl_key_password)
conn_props["x-ssl-allow-cleartext"] = self.ssl_allow_insecure conn_props["x-ssl-allow-cleartext"] = self.ssl_allow_insecure
self._socket_connection = self.processor.connect(hostname, port, self._socket_connection = self.processor.connect(host,
handler=self, handler=self,
properties=conn_props) properties=conn_props)
LOG.debug("Connection initiated") LOG.debug("Connection initiated")
@ -535,6 +534,17 @@ class Controller(pyngus.ConnectionEventHandler):
reason or "no reason given") reason or "no reason given")
self._socket_connection.connection.close() self._socket_connection.connection.close()
def sasl_done(self, connection, pn_sasl, outcome):
"""This is a Pyngus callback invoked by Pyngus when the SASL handshake
has completed. The outcome of the handshake will be OK on success or
AUTH on failure.
"""
if outcome == proton.SASL.AUTH:
LOG.error("Unable to connect to %s:%s, authentication failure.",
self.hosts.current.hostname, self.hosts.current.port)
# requires user intervention, treat it like a connection failure:
self._handle_connection_loss()
def _complete_shutdown(self): def _complete_shutdown(self):
"""The AMQP Connection has closed, and the driver shutdown is complete. """The AMQP Connection has closed, and the driver shutdown is complete.
Clean up controller resources and exit. Clean up controller resources and exit.
@ -574,6 +584,6 @@ class Controller(pyngus.ConnectionEventHandler):
self._reconnecting = False self._reconnecting = False
self._senders = {} self._senders = {}
self._socket_connection.reset() self._socket_connection.reset()
hostname, port = self.hosts.next() host = self.hosts.next()
LOG.info("Reconnecting to: %s:%i", hostname, port) LOG.info("Reconnecting to: %s:%i", host.hostname, host.port)
self._socket_connection.connect(hostname, port) self._socket_connection.connect(host)

View File

@ -190,11 +190,9 @@ class ProtonDriver(base.BaseDriver):
super(ProtonDriver, self).__init__(conf, url, default_exchange, super(ProtonDriver, self).__init__(conf, url, default_exchange,
allowed_remote_exmods) allowed_remote_exmods)
# TODO(grs): handle authentication etc
hosts = [(h.hostname, h.port or 5672) for h in url.hosts]
# Create a Controller that connects to the messaging service: # Create a Controller that connects to the messaging service:
self._ctrl = controller.Controller(hosts, default_exchange, conf) self._ctrl = controller.Controller(url.hosts, default_exchange, conf)
# lazy connection setup - don't cause the controller to connect until # lazy connection setup - don't cause the controller to connect until
# after the first messaging request: # after the first messaging request:

View File

@ -101,12 +101,12 @@ class _SocketConnection():
self._handler.socket_error(str(e)) self._handler.socket_error(str(e))
return pyngus.Connection.EOS return pyngus.Connection.EOS
def connect(self, hostname, port, sasl_mechanisms="ANONYMOUS"): def connect(self, host):
"""Connect to host:port and start the AMQP protocol.""" """Connect to host:port and start the AMQP protocol."""
addr = socket.getaddrinfo(hostname, port, addr = socket.getaddrinfo(host.hostname, host.port,
socket.AF_INET, socket.SOCK_STREAM) socket.AF_INET, socket.SOCK_STREAM)
if not addr: if not addr:
key = "%s:%i" % (hostname, port) key = "%s:%i" % (host.hostname, host.port)
error = "Invalid peer address '%s'" % key error = "Invalid peer address '%s'" % key
LOG.error(error) LOG.error(error)
self._handler.socket_error(error) self._handler.socket_error(error)
@ -124,9 +124,14 @@ class _SocketConnection():
return return
self.socket = my_socket self.socket = my_socket
if sasl_mechanisms: # determine the proper SASL mechanism: PLAIN if a username/password is
pn_sasl = self.connection.pn_sasl # present, else ANONYMOUS
pn_sasl.mechanisms(sasl_mechanisms) pn_sasl = self.connection.pn_sasl
if host.username:
password = host.password if host.password else ""
pn_sasl.plain(host.username, password)
else:
pn_sasl.mechanisms("ANONYMOUS")
# TODO(kgiusti): server if accepting inbound connections # TODO(kgiusti): server if accepting inbound connections
pn_sasl.client() pn_sasl.client()
self.connection.open() self.connection.open()
@ -259,10 +264,9 @@ class Thread(threading.Thread):
LOG.info("eventloop shutdown requested") LOG.info("eventloop shutdown requested")
self._shutdown = True self._shutdown = True
def connect(self, hostname, port, handler, properties=None, name=None, def connect(self, host, handler, properties=None, name=None):
sasl_mechanisms="ANONYMOUS"):
"""Get a _SocketConnection to a peer represented by url.""" """Get a _SocketConnection to a peer represented by url."""
key = name or "%s:%i" % (hostname, port) key = name or "%s:%i" % (host.hostname, host.port)
# return pre-existing # return pre-existing
conn = self._container.get_connection(key) conn = self._container.get_connection(key)
if conn: if conn:
@ -273,7 +277,7 @@ class Thread(threading.Thread):
# no name was provided, the host:port combination # no name was provided, the host:port combination
sc = _SocketConnection(key, self._container, sc = _SocketConnection(key, self._container,
properties, handler=handler) properties, handler=handler)
sc.connect(hostname, port, sasl_mechanisms) sc.connect(host)
return sc return sc
def run(self): def run(self):

View File

@ -282,6 +282,62 @@ class TestAmqpNotification(_AmqpBrokerTestCase):
driver.cleanup() driver.cleanup()
@testtools.skipUnless(pyngus, "proton modules not present")
class TestAuthentication(test_utils.BaseTestCase):
def setUp(self):
super(TestAuthentication, self).setUp()
LOG.error("Starting Authentication Test")
# for simplicity, encode the credentials as they would appear 'on the
# wire' in a SASL frame - username and password prefixed by zero.
user_credentials = ["\0joe\0secret"]
self._broker = FakeBroker(sasl_mechanisms="PLAIN",
user_credentials=user_credentials)
self._broker.start()
def tearDown(self):
super(TestAuthentication, self).tearDown()
self._broker.stop()
LOG.error("Authentication Test Ended")
def test_authentication_ok(self):
"""Verify that username and password given in TransportHost are
accepted by the broker.
"""
addr = "amqp://joe:secret@%s:%d" % (self._broker.host,
self._broker.port)
url = messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url)
target = messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1)
rc = driver.send(target, {"context": True},
{"method": "echo"}, wait_for_reply=True)
self.assertIsNotNone(rc)
listener.join(timeout=30)
self.assertFalse(listener.isAlive())
driver.cleanup()
def test_authentication_failure(self):
"""Verify that a bad password given in TransportHost is
rejected by the broker.
"""
addr = "amqp://joe:badpass@%s:%d" % (self._broker.host,
self._broker.port)
url = messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url)
target = messaging.Target(topic="test-topic")
_ListenerThread(driver.listen(target), 1)
self.assertRaises(messaging.MessagingTimeout,
driver.send,
target, {"context": True},
{"method": "echo"},
wait_for_reply=True,
timeout=2.0)
driver.cleanup()
@testtools.skipUnless(pyngus, "proton modules not present") @testtools.skipUnless(pyngus, "proton modules not present")
class TestFailover(test_utils.BaseTestCase): class TestFailover(test_utils.BaseTestCase):
@ -362,7 +418,8 @@ class FakeBroker(threading.Thread):
class Connection(pyngus.ConnectionEventHandler): class Connection(pyngus.ConnectionEventHandler):
"""A single AMQP connection.""" """A single AMQP connection."""
def __init__(self, server, socket_, name): def __init__(self, server, socket_, name,
sasl_mechanisms, user_credentials):
"""Create a Connection using socket_.""" """Create a Connection using socket_."""
self.socket = socket_ self.socket = socket_
self.name = name self.name = name
@ -370,8 +427,11 @@ class FakeBroker(threading.Thread):
self.connection = server.container.create_connection(name, self.connection = server.container.create_connection(name,
self) self)
self.connection.user_context = self self.connection.user_context = self
self.connection.pn_sasl.mechanisms("ANONYMOUS") self.sasl_mechanisms = sasl_mechanisms
self.connection.pn_sasl.server() self.user_credentials = user_credentials
if sasl_mechanisms:
self.connection.pn_sasl.mechanisms(sasl_mechanisms)
self.connection.pn_sasl.server()
self.connection.open() self.connection.open()
self.sender_links = set() self.sender_links = set()
self.closed = False self.closed = False
@ -436,7 +496,14 @@ class FakeBroker(threading.Thread):
link_handle, addr) link_handle, addr)
def sasl_step(self, connection, pn_sasl): def sasl_step(self, connection, pn_sasl):
pn_sasl.done(pn_sasl.OK) # always permit if self.sasl_mechanisms == 'PLAIN':
credentials = pn_sasl.recv()
if not credentials:
return # wait until some arrives
if credentials not in self.user_credentials:
# failed
return pn_sasl.done(pn_sasl.AUTH)
pn_sasl.done(pn_sasl.OK)
class SenderLink(pyngus.SenderEventHandler): class SenderLink(pyngus.SenderEventHandler):
"""An AMQP sending link.""" """An AMQP sending link."""
@ -513,7 +580,9 @@ class FakeBroker(threading.Thread):
broadcast_prefix="broadcast", broadcast_prefix="broadcast",
group_prefix="unicast", group_prefix="unicast",
address_separator=".", address_separator=".",
sock_addr="", sock_port=0): sock_addr="", sock_port=0,
sasl_mechanisms="ANONYMOUS",
user_credentials=None):
"""Create a fake broker listening on sock_addr:sock_port.""" """Create a fake broker listening on sock_addr:sock_port."""
if not pyngus: if not pyngus:
raise AssertionError("pyngus module not present") raise AssertionError("pyngus module not present")
@ -522,6 +591,8 @@ class FakeBroker(threading.Thread):
self._broadcast_prefix = broadcast_prefix + address_separator self._broadcast_prefix = broadcast_prefix + address_separator
self._group_prefix = group_prefix + address_separator self._group_prefix = group_prefix + address_separator
self._address_separator = address_separator self._address_separator = address_separator
self._sasl_mechanisms = sasl_mechanisms
self._user_credentials = user_credentials
self._wakeup_pipe = os.pipe() self._wakeup_pipe = os.pipe()
self._my_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._my_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._my_socket.bind((sock_addr, sock_port)) self._my_socket.bind((sock_addr, sock_port))
@ -581,7 +652,9 @@ class FakeBroker(threading.Thread):
# create a new Connection for it: # create a new Connection for it:
client_socket, client_address = self._my_socket.accept() client_socket, client_address = self._my_socket.accept()
name = str(client_address) name = str(client_address)
conn = FakeBroker.Connection(self, client_socket, name) conn = FakeBroker.Connection(self, client_socket, name,
self._sasl_mechanisms,
self._user_credentials)
self._connections[conn.name] = conn self._connections[conn.name] = conn
elif r is self._wakeup_pipe[0]: elif r is self._wakeup_pipe[0]:
os.read(self._wakeup_pipe[0], 512) os.read(self._wakeup_pipe[0], 512)