Use callbacks for sasl handshake request / response

This commit is contained in:
Dana Powers
2016-08-03 11:45:50 -07:00
parent 6b801a8d2e
commit 2b2c72feac

View File

@@ -74,10 +74,11 @@ class BrokerConnection(object):
'ssl_password': None,
'api_version': (0, 8, 2), # default to most restrictive
'state_change_callback': lambda conn: True,
'sasl_mechanism': None,
'sasl_mechanism': 'PLAIN',
'sasl_plain_username': None,
'sasl_plain_password': None
}
SASL_MECHANISMS = ('PLAIN',)
def __init__(self, host, port, afi, **configs):
self.host = host
@@ -100,11 +101,19 @@ class BrokerConnection(object):
(socket.SOL_SOCKET, socket.SO_SNDBUF,
self.config['send_buffer_bytes']))
if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'):
assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, (
'sasl_mechanism must be in ' + self.SASL_MECHANISMS)
if self.config['sasl_mechanism'] == 'PLAIN':
assert self.config['sasl_plain_username'] is not None, 'sasl_plain_username required for PLAIN sasl'
assert self.config['sasl_plain_password'] is not None, 'sasl_plain_password required for PLAIN sasl'
self.state = ConnectionStates.DISCONNECTED
self._sock = None
self._ssl_context = None
if self.config['ssl_context'] is not None:
self._ssl_context = self.config['ssl_context']
self._sasl_auth_future = None
self._rbuffer = io.BytesIO()
self._receiving = False
self._next_payload_bytes = 0
@@ -224,8 +233,9 @@ class BrokerConnection(object):
self.config['state_change_callback'](self)
if self.state is ConnectionStates.AUTHENTICATING:
assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
if self._try_authenticate():
log.debug('%s: Authenticated as %s', str(self), self.config['sasl_plain_username'])
log.info('%s: Authenticated as %s', str(self), self.config['sasl_plain_username'])
self.state = ConnectionStates.CONNECTED
self.config['state_change_callback'](self)
@@ -289,58 +299,44 @@ class BrokerConnection(object):
return False
def _try_authenticate(self):
assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
assert self.config['api_version'] >= (0, 10) or self.config['api_version'] is None
if self._sasl_auth_future is None:
# Build a SaslHandShakeRequest message
request = SaslHandShakeRequest[0](self.config['sasl_mechanism'])
future = Future()
sasl_response = self._send(request)
sasl_response.add_callback(self._handle_sasl_handshake_response, future)
sasl_response.add_errback(lambda f, e: f.failure(e), future)
self._sasl_auth_future = future
self._recv()
if self._sasl_auth_future.failed():
raise self._sasl_auth_future.exception
return self._sasl_auth_future.succeeded()
def _handle_sasl_handshake_response(self, future, response):
error_type = Errors.for_code(response.error_code)
if error_type is not Errors.NoError:
error = error_type(self)
self.close(error=error)
return future.failure(error_type(self))
if self.config['sasl_mechanism'] == 'PLAIN':
return self._try_authenticate_plain(future)
else:
return future.failure(
Errors.UnsupportedSaslMechanismError(
'kafka-python does not support SASL mechanism %s' %
self.config['sasl_mechanism']))
def _try_authenticate_plain(self, future):
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
log.warning('%s: Sending username and password in the clear', str(self))
# Build a SaslHandShakeRequest message
correlation_id = self._next_correlation_id()
request = SaslHandShakeRequest[0](self.config['sasl_mechanism'])
header = RequestHeader(request,
correlation_id=correlation_id,
client_id=self.config['client_id'])
message = b''.join([header.encode(), request.encode()])
size = Int32.encode(len(message))
# Attempt to send it over our socket
try:
self._sock.setblocking(True)
self._sock.sendall(size + message)
self._sock.setblocking(False)
except (AssertionError, ConnectionError) as e:
log.exception("Error sending %s to %s", request, self)
error = Errors.ConnectionError("%s: %s" % (str(self), e))
self.close(error=error)
return False
future = Future()
ifr = InFlightRequest(request=request,
correlation_id=correlation_id,
response_type=request.RESPONSE_TYPE,
future=future,
timestamp=time.time())
self.in_flight_requests.append(ifr)
# Listen for a reply and check that the server supports the PLAIN mechanism
response = None
while not response:
response = self.recv()
if not response.error_code is 0:
raise Errors.for_code(response.error_code)
if not self.config['sasl_mechanism'] in response.enabled_mechanisms:
raise Errors.AuthenticationMethodNotSupported(self.config['sasl_mechanism'] + " is not supported by broker")
return self._try_authenticate_plain()
def _try_authenticate_plain(self):
data = b''
try:
self._sock.setblocking(True)
# Send our credentials
# Send PLAIN credentials per RFC-4616
msg = bytes('\0'.join([self.config['sasl_plain_username'],
self.config['sasl_plain_username'],
self.config['sasl_plain_password']]).encode('utf-8'))
@@ -351,26 +347,26 @@ class BrokerConnection(object):
# The connection is closed on failure
received_bytes = 0
while received_bytes < 4:
data = data + self._sock.recv(4 - received_bytes)
received_bytes = received_bytes + len(data)
data += self._sock.recv(4 - received_bytes)
received_bytes += len(data)
if not data:
log.error('%s: Authentication failed for user %s', self, self.config['sasl_plain_username'])
self.close(error=Errors.ConnectionError('Authentication failed'))
raise Errors.AuthenticationFailedError('Authentication failed for user {}'.format(self.config['sasl_plain_username']))
error = Errors.AuthenticationFailedError(
'Authentication failed for user {0}'.format(
self.config['sasl_plain_username']))
future.failure(error)
raise error
self._sock.setblocking(False)
except (AssertionError, ConnectionError) as e:
log.exception("%s: Error receiving reply from server", self)
error = Errors.ConnectionError("%s: %s" % (str(self), e))
future.failure(error)
self.close(error=error)
return False
with io.BytesIO() as buffer:
buffer.write(data)
buffer.seek(0)
if not Int32.decode(buffer) == 0:
raise Errors.KafkaError('Expected a zero sized reply after sending credentials')
if data != '\x00\x00\x00\x00':
return future.failure(Errors.AuthenticationFailedError())
return True
return future.success(True)
def blacked_out(self):
"""
@@ -437,6 +433,10 @@ class BrokerConnection(object):
return future.failure(Errors.ConnectionError(str(self)))
elif not self.can_send_more():
return future.failure(Errors.TooManyInFlightRequests(str(self)))
return self._send(request, expect_response=expect_response)
def _send(self, request, expect_response=True):
future = Future()
correlation_id = self._next_correlation_id()
header = RequestHeader(request,
correlation_id=correlation_id,
@@ -505,6 +505,9 @@ class BrokerConnection(object):
self.config['request_timeout_ms']))
return None
return self._recv()
def _recv(self):
# Not receiving is the state of reading the payload header
if not self._receiving:
try:
@@ -552,7 +555,7 @@ class BrokerConnection(object):
# enough data to read the full bytes_to_read
# but if the socket is disconnected, we will get empty data
# without an exception raised
if not data:
if bytes_to_read and not data:
log.error('%s: socket disconnected', self)
self.close(error=Errors.ConnectionError('socket disconnected'))
return None