Merge "Enable user authentication in the AMQP 1.0 driver"
This commit is contained in:
commit
1a53d3ed30
@ -25,17 +25,18 @@ functions scheduled by the Controller.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import collections
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
import proton
|
||||||
import pyngus
|
import pyngus
|
||||||
from six import moves
|
from six import moves
|
||||||
|
|
||||||
from oslo.config import cfg
|
from oslo.config import cfg
|
||||||
from oslo.messaging._drivers.protocols.amqp import eventloop
|
from oslo.messaging._drivers.protocols.amqp import eventloop
|
||||||
from oslo.messaging._drivers.protocols.amqp import opts
|
from oslo.messaging._drivers.protocols.amqp import opts
|
||||||
|
from oslo.messaging import transport
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -188,25 +189,24 @@ class Server(pyngus.ReceiverEventHandler):
|
|||||||
|
|
||||||
|
|
||||||
class Hosts(object):
|
class Hosts(object):
|
||||||
"""An order list of peer addresses. Connection failover progresses from
|
"""An order list of TransportHost addresses. Connection failover
|
||||||
one host to the next.
|
progresses from one host to the next.
|
||||||
"""
|
"""
|
||||||
HostnamePort = collections.namedtuple('HostnamePort',
|
|
||||||
['hostname', 'port'])
|
|
||||||
|
|
||||||
def __init__(self, entries=None):
|
def __init__(self, entries=None):
|
||||||
self._entries = [self.HostnamePort(h, p) for h, p in entries or []]
|
self._entries = entries[:] if entries else []
|
||||||
|
for entry in self._entries:
|
||||||
|
entry.port = entry.port or 5672
|
||||||
self._current = 0
|
self._current = 0
|
||||||
|
|
||||||
def add(self, hostname, port=5672):
|
def add(self, transport_host):
|
||||||
self._entries.append(self.HostnamePort(hostname, port))
|
self._entries.append(transport_host)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current(self):
|
def current(self):
|
||||||
if len(self._entries):
|
if len(self._entries):
|
||||||
return self._entries[self._current]
|
return self._entries[self._current]
|
||||||
else:
|
else:
|
||||||
return self.HostnamePort("localhost", 5672)
|
return transport.TransportHost(hostname="localhost", port=5672)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
if len(self._entries) > 1:
|
if len(self._entries) > 1:
|
||||||
@ -217,7 +217,7 @@ class Hosts(object):
|
|||||||
return '<Hosts ' + str(self) + '>'
|
return '<Hosts ' + str(self) + '>'
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return ", ".join(["%s:%i" % e for e in self._entries])
|
return ", ".join(["%r" % th for th in self._entries])
|
||||||
|
|
||||||
|
|
||||||
class Controller(pyngus.ConnectionEventHandler):
|
class Controller(pyngus.ConnectionEventHandler):
|
||||||
@ -394,8 +394,7 @@ class Controller(pyngus.ConnectionEventHandler):
|
|||||||
|
|
||||||
def _do_connect(self):
|
def _do_connect(self):
|
||||||
"""Establish connection and reply subscription on processor thread."""
|
"""Establish connection and reply subscription on processor thread."""
|
||||||
hostname = self.hosts.current.hostname
|
host = self.hosts.current
|
||||||
port = self.hosts.current.port
|
|
||||||
conn_props = {}
|
conn_props = {}
|
||||||
if self.idle_timeout:
|
if self.idle_timeout:
|
||||||
conn_props["idle-time-out"] = float(self.idle_timeout)
|
conn_props["idle-time-out"] = float(self.idle_timeout)
|
||||||
@ -412,7 +411,7 @@ class Controller(pyngus.ConnectionEventHandler):
|
|||||||
self.ssl_key_file,
|
self.ssl_key_file,
|
||||||
self.ssl_key_password)
|
self.ssl_key_password)
|
||||||
conn_props["x-ssl-allow-cleartext"] = self.ssl_allow_insecure
|
conn_props["x-ssl-allow-cleartext"] = self.ssl_allow_insecure
|
||||||
self._socket_connection = self.processor.connect(hostname, port,
|
self._socket_connection = self.processor.connect(host,
|
||||||
handler=self,
|
handler=self,
|
||||||
properties=conn_props)
|
properties=conn_props)
|
||||||
LOG.debug("Connection initiated")
|
LOG.debug("Connection initiated")
|
||||||
@ -535,6 +534,17 @@ class Controller(pyngus.ConnectionEventHandler):
|
|||||||
reason or "no reason given")
|
reason or "no reason given")
|
||||||
self._socket_connection.connection.close()
|
self._socket_connection.connection.close()
|
||||||
|
|
||||||
|
def sasl_done(self, connection, pn_sasl, outcome):
|
||||||
|
"""This is a Pyngus callback invoked by Pyngus when the SASL handshake
|
||||||
|
has completed. The outcome of the handshake will be OK on success or
|
||||||
|
AUTH on failure.
|
||||||
|
"""
|
||||||
|
if outcome == proton.SASL.AUTH:
|
||||||
|
LOG.error("Unable to connect to %s:%s, authentication failure.",
|
||||||
|
self.hosts.current.hostname, self.hosts.current.port)
|
||||||
|
# requires user intervention, treat it like a connection failure:
|
||||||
|
self._handle_connection_loss()
|
||||||
|
|
||||||
def _complete_shutdown(self):
|
def _complete_shutdown(self):
|
||||||
"""The AMQP Connection has closed, and the driver shutdown is complete.
|
"""The AMQP Connection has closed, and the driver shutdown is complete.
|
||||||
Clean up controller resources and exit.
|
Clean up controller resources and exit.
|
||||||
@ -574,6 +584,6 @@ class Controller(pyngus.ConnectionEventHandler):
|
|||||||
self._reconnecting = False
|
self._reconnecting = False
|
||||||
self._senders = {}
|
self._senders = {}
|
||||||
self._socket_connection.reset()
|
self._socket_connection.reset()
|
||||||
hostname, port = self.hosts.next()
|
host = self.hosts.next()
|
||||||
LOG.info("Reconnecting to: %s:%i", hostname, port)
|
LOG.info("Reconnecting to: %s:%i", host.hostname, host.port)
|
||||||
self._socket_connection.connect(hostname, port)
|
self._socket_connection.connect(host)
|
||||||
|
@ -190,11 +190,9 @@ class ProtonDriver(base.BaseDriver):
|
|||||||
|
|
||||||
super(ProtonDriver, self).__init__(conf, url, default_exchange,
|
super(ProtonDriver, self).__init__(conf, url, default_exchange,
|
||||||
allowed_remote_exmods)
|
allowed_remote_exmods)
|
||||||
# TODO(grs): handle authentication etc
|
|
||||||
hosts = [(h.hostname, h.port or 5672) for h in url.hosts]
|
|
||||||
|
|
||||||
# Create a Controller that connects to the messaging service:
|
# Create a Controller that connects to the messaging service:
|
||||||
self._ctrl = controller.Controller(hosts, default_exchange, conf)
|
self._ctrl = controller.Controller(url.hosts, default_exchange, conf)
|
||||||
|
|
||||||
# lazy connection setup - don't cause the controller to connect until
|
# lazy connection setup - don't cause the controller to connect until
|
||||||
# after the first messaging request:
|
# after the first messaging request:
|
||||||
|
@ -101,12 +101,12 @@ class _SocketConnection():
|
|||||||
self._handler.socket_error(str(e))
|
self._handler.socket_error(str(e))
|
||||||
return pyngus.Connection.EOS
|
return pyngus.Connection.EOS
|
||||||
|
|
||||||
def connect(self, hostname, port, sasl_mechanisms="ANONYMOUS"):
|
def connect(self, host):
|
||||||
"""Connect to host:port and start the AMQP protocol."""
|
"""Connect to host:port and start the AMQP protocol."""
|
||||||
addr = socket.getaddrinfo(hostname, port,
|
addr = socket.getaddrinfo(host.hostname, host.port,
|
||||||
socket.AF_INET, socket.SOCK_STREAM)
|
socket.AF_INET, socket.SOCK_STREAM)
|
||||||
if not addr:
|
if not addr:
|
||||||
key = "%s:%i" % (hostname, port)
|
key = "%s:%i" % (host.hostname, host.port)
|
||||||
error = "Invalid peer address '%s'" % key
|
error = "Invalid peer address '%s'" % key
|
||||||
LOG.error(error)
|
LOG.error(error)
|
||||||
self._handler.socket_error(error)
|
self._handler.socket_error(error)
|
||||||
@ -124,9 +124,14 @@ class _SocketConnection():
|
|||||||
return
|
return
|
||||||
self.socket = my_socket
|
self.socket = my_socket
|
||||||
|
|
||||||
if sasl_mechanisms:
|
# determine the proper SASL mechanism: PLAIN if a username/password is
|
||||||
|
# present, else ANONYMOUS
|
||||||
pn_sasl = self.connection.pn_sasl
|
pn_sasl = self.connection.pn_sasl
|
||||||
pn_sasl.mechanisms(sasl_mechanisms)
|
if host.username:
|
||||||
|
password = host.password if host.password else ""
|
||||||
|
pn_sasl.plain(host.username, password)
|
||||||
|
else:
|
||||||
|
pn_sasl.mechanisms("ANONYMOUS")
|
||||||
# TODO(kgiusti): server if accepting inbound connections
|
# TODO(kgiusti): server if accepting inbound connections
|
||||||
pn_sasl.client()
|
pn_sasl.client()
|
||||||
self.connection.open()
|
self.connection.open()
|
||||||
@ -259,10 +264,9 @@ class Thread(threading.Thread):
|
|||||||
LOG.info("eventloop shutdown requested")
|
LOG.info("eventloop shutdown requested")
|
||||||
self._shutdown = True
|
self._shutdown = True
|
||||||
|
|
||||||
def connect(self, hostname, port, handler, properties=None, name=None,
|
def connect(self, host, handler, properties=None, name=None):
|
||||||
sasl_mechanisms="ANONYMOUS"):
|
|
||||||
"""Get a _SocketConnection to a peer represented by url."""
|
"""Get a _SocketConnection to a peer represented by url."""
|
||||||
key = name or "%s:%i" % (hostname, port)
|
key = name or "%s:%i" % (host.hostname, host.port)
|
||||||
# return pre-existing
|
# return pre-existing
|
||||||
conn = self._container.get_connection(key)
|
conn = self._container.get_connection(key)
|
||||||
if conn:
|
if conn:
|
||||||
@ -273,7 +277,7 @@ class Thread(threading.Thread):
|
|||||||
# no name was provided, the host:port combination
|
# no name was provided, the host:port combination
|
||||||
sc = _SocketConnection(key, self._container,
|
sc = _SocketConnection(key, self._container,
|
||||||
properties, handler=handler)
|
properties, handler=handler)
|
||||||
sc.connect(hostname, port, sasl_mechanisms)
|
sc.connect(host)
|
||||||
return sc
|
return sc
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
@ -282,6 +282,62 @@ class TestAmqpNotification(_AmqpBrokerTestCase):
|
|||||||
driver.cleanup()
|
driver.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
@testtools.skipUnless(pyngus, "proton modules not present")
|
||||||
|
class TestAuthentication(test_utils.BaseTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(TestAuthentication, self).setUp()
|
||||||
|
LOG.error("Starting Authentication Test")
|
||||||
|
# for simplicity, encode the credentials as they would appear 'on the
|
||||||
|
# wire' in a SASL frame - username and password prefixed by zero.
|
||||||
|
user_credentials = ["\0joe\0secret"]
|
||||||
|
self._broker = FakeBroker(sasl_mechanisms="PLAIN",
|
||||||
|
user_credentials=user_credentials)
|
||||||
|
self._broker.start()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
super(TestAuthentication, self).tearDown()
|
||||||
|
self._broker.stop()
|
||||||
|
LOG.error("Authentication Test Ended")
|
||||||
|
|
||||||
|
def test_authentication_ok(self):
|
||||||
|
"""Verify that username and password given in TransportHost are
|
||||||
|
accepted by the broker.
|
||||||
|
"""
|
||||||
|
|
||||||
|
addr = "amqp://joe:secret@%s:%d" % (self._broker.host,
|
||||||
|
self._broker.port)
|
||||||
|
url = messaging.TransportURL.parse(self.conf, addr)
|
||||||
|
driver = amqp_driver.ProtonDriver(self.conf, url)
|
||||||
|
target = messaging.Target(topic="test-topic")
|
||||||
|
listener = _ListenerThread(driver.listen(target), 1)
|
||||||
|
rc = driver.send(target, {"context": True},
|
||||||
|
{"method": "echo"}, wait_for_reply=True)
|
||||||
|
self.assertIsNotNone(rc)
|
||||||
|
listener.join(timeout=30)
|
||||||
|
self.assertFalse(listener.isAlive())
|
||||||
|
driver.cleanup()
|
||||||
|
|
||||||
|
def test_authentication_failure(self):
|
||||||
|
"""Verify that a bad password given in TransportHost is
|
||||||
|
rejected by the broker.
|
||||||
|
"""
|
||||||
|
|
||||||
|
addr = "amqp://joe:badpass@%s:%d" % (self._broker.host,
|
||||||
|
self._broker.port)
|
||||||
|
url = messaging.TransportURL.parse(self.conf, addr)
|
||||||
|
driver = amqp_driver.ProtonDriver(self.conf, url)
|
||||||
|
target = messaging.Target(topic="test-topic")
|
||||||
|
_ListenerThread(driver.listen(target), 1)
|
||||||
|
self.assertRaises(messaging.MessagingTimeout,
|
||||||
|
driver.send,
|
||||||
|
target, {"context": True},
|
||||||
|
{"method": "echo"},
|
||||||
|
wait_for_reply=True,
|
||||||
|
timeout=2.0)
|
||||||
|
driver.cleanup()
|
||||||
|
|
||||||
|
|
||||||
@testtools.skipUnless(pyngus, "proton modules not present")
|
@testtools.skipUnless(pyngus, "proton modules not present")
|
||||||
class TestFailover(test_utils.BaseTestCase):
|
class TestFailover(test_utils.BaseTestCase):
|
||||||
|
|
||||||
@ -362,7 +418,8 @@ class FakeBroker(threading.Thread):
|
|||||||
class Connection(pyngus.ConnectionEventHandler):
|
class Connection(pyngus.ConnectionEventHandler):
|
||||||
"""A single AMQP connection."""
|
"""A single AMQP connection."""
|
||||||
|
|
||||||
def __init__(self, server, socket_, name):
|
def __init__(self, server, socket_, name,
|
||||||
|
sasl_mechanisms, user_credentials):
|
||||||
"""Create a Connection using socket_."""
|
"""Create a Connection using socket_."""
|
||||||
self.socket = socket_
|
self.socket = socket_
|
||||||
self.name = name
|
self.name = name
|
||||||
@ -370,7 +427,10 @@ class FakeBroker(threading.Thread):
|
|||||||
self.connection = server.container.create_connection(name,
|
self.connection = server.container.create_connection(name,
|
||||||
self)
|
self)
|
||||||
self.connection.user_context = self
|
self.connection.user_context = self
|
||||||
self.connection.pn_sasl.mechanisms("ANONYMOUS")
|
self.sasl_mechanisms = sasl_mechanisms
|
||||||
|
self.user_credentials = user_credentials
|
||||||
|
if sasl_mechanisms:
|
||||||
|
self.connection.pn_sasl.mechanisms(sasl_mechanisms)
|
||||||
self.connection.pn_sasl.server()
|
self.connection.pn_sasl.server()
|
||||||
self.connection.open()
|
self.connection.open()
|
||||||
self.sender_links = set()
|
self.sender_links = set()
|
||||||
@ -436,7 +496,14 @@ class FakeBroker(threading.Thread):
|
|||||||
link_handle, addr)
|
link_handle, addr)
|
||||||
|
|
||||||
def sasl_step(self, connection, pn_sasl):
|
def sasl_step(self, connection, pn_sasl):
|
||||||
pn_sasl.done(pn_sasl.OK) # always permit
|
if self.sasl_mechanisms == 'PLAIN':
|
||||||
|
credentials = pn_sasl.recv()
|
||||||
|
if not credentials:
|
||||||
|
return # wait until some arrives
|
||||||
|
if credentials not in self.user_credentials:
|
||||||
|
# failed
|
||||||
|
return pn_sasl.done(pn_sasl.AUTH)
|
||||||
|
pn_sasl.done(pn_sasl.OK)
|
||||||
|
|
||||||
class SenderLink(pyngus.SenderEventHandler):
|
class SenderLink(pyngus.SenderEventHandler):
|
||||||
"""An AMQP sending link."""
|
"""An AMQP sending link."""
|
||||||
@ -513,7 +580,9 @@ class FakeBroker(threading.Thread):
|
|||||||
broadcast_prefix="broadcast",
|
broadcast_prefix="broadcast",
|
||||||
group_prefix="unicast",
|
group_prefix="unicast",
|
||||||
address_separator=".",
|
address_separator=".",
|
||||||
sock_addr="", sock_port=0):
|
sock_addr="", sock_port=0,
|
||||||
|
sasl_mechanisms="ANONYMOUS",
|
||||||
|
user_credentials=None):
|
||||||
"""Create a fake broker listening on sock_addr:sock_port."""
|
"""Create a fake broker listening on sock_addr:sock_port."""
|
||||||
if not pyngus:
|
if not pyngus:
|
||||||
raise AssertionError("pyngus module not present")
|
raise AssertionError("pyngus module not present")
|
||||||
@ -522,6 +591,8 @@ class FakeBroker(threading.Thread):
|
|||||||
self._broadcast_prefix = broadcast_prefix + address_separator
|
self._broadcast_prefix = broadcast_prefix + address_separator
|
||||||
self._group_prefix = group_prefix + address_separator
|
self._group_prefix = group_prefix + address_separator
|
||||||
self._address_separator = address_separator
|
self._address_separator = address_separator
|
||||||
|
self._sasl_mechanisms = sasl_mechanisms
|
||||||
|
self._user_credentials = user_credentials
|
||||||
self._wakeup_pipe = os.pipe()
|
self._wakeup_pipe = os.pipe()
|
||||||
self._my_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self._my_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self._my_socket.bind((sock_addr, sock_port))
|
self._my_socket.bind((sock_addr, sock_port))
|
||||||
@ -581,7 +652,9 @@ class FakeBroker(threading.Thread):
|
|||||||
# create a new Connection for it:
|
# create a new Connection for it:
|
||||||
client_socket, client_address = self._my_socket.accept()
|
client_socket, client_address = self._my_socket.accept()
|
||||||
name = str(client_address)
|
name = str(client_address)
|
||||||
conn = FakeBroker.Connection(self, client_socket, name)
|
conn = FakeBroker.Connection(self, client_socket, name,
|
||||||
|
self._sasl_mechanisms,
|
||||||
|
self._user_credentials)
|
||||||
self._connections[conn.name] = conn
|
self._connections[conn.name] = conn
|
||||||
elif r is self._wakeup_pipe[0]:
|
elif r is self._wakeup_pipe[0]:
|
||||||
os.read(self._wakeup_pipe[0], 512)
|
os.read(self._wakeup_pipe[0], 512)
|
||||||
|
Loading…
Reference in New Issue
Block a user