Use callbacks for sasl handshake request / response
This commit is contained in:
121
kafka/conn.py
121
kafka/conn.py
@@ -74,10 +74,11 @@ class BrokerConnection(object):
|
||||
'ssl_password': None,
|
||||
'api_version': (0, 8, 2), # default to most restrictive
|
||||
'state_change_callback': lambda conn: True,
|
||||
'sasl_mechanism': None,
|
||||
'sasl_mechanism': 'PLAIN',
|
||||
'sasl_plain_username': None,
|
||||
'sasl_plain_password': None
|
||||
}
|
||||
SASL_MECHANISMS = ('PLAIN',)
|
||||
|
||||
def __init__(self, host, port, afi, **configs):
|
||||
self.host = host
|
||||
@@ -100,11 +101,19 @@ class BrokerConnection(object):
|
||||
(socket.SOL_SOCKET, socket.SO_SNDBUF,
|
||||
self.config['send_buffer_bytes']))
|
||||
|
||||
if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'):
|
||||
assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, (
|
||||
'sasl_mechanism must be in ' + self.SASL_MECHANISMS)
|
||||
if self.config['sasl_mechanism'] == 'PLAIN':
|
||||
assert self.config['sasl_plain_username'] is not None, 'sasl_plain_username required for PLAIN sasl'
|
||||
assert self.config['sasl_plain_password'] is not None, 'sasl_plain_password required for PLAIN sasl'
|
||||
|
||||
self.state = ConnectionStates.DISCONNECTED
|
||||
self._sock = None
|
||||
self._ssl_context = None
|
||||
if self.config['ssl_context'] is not None:
|
||||
self._ssl_context = self.config['ssl_context']
|
||||
self._sasl_auth_future = None
|
||||
self._rbuffer = io.BytesIO()
|
||||
self._receiving = False
|
||||
self._next_payload_bytes = 0
|
||||
@@ -224,8 +233,9 @@ class BrokerConnection(object):
|
||||
self.config['state_change_callback'](self)
|
||||
|
||||
if self.state is ConnectionStates.AUTHENTICATING:
|
||||
assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
|
||||
if self._try_authenticate():
|
||||
log.debug('%s: Authenticated as %s', str(self), self.config['sasl_plain_username'])
|
||||
log.info('%s: Authenticated as %s', str(self), self.config['sasl_plain_username'])
|
||||
self.state = ConnectionStates.CONNECTED
|
||||
self.config['state_change_callback'](self)
|
||||
|
||||
@@ -289,58 +299,44 @@ class BrokerConnection(object):
|
||||
return False
|
||||
|
||||
def _try_authenticate(self):
|
||||
assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
|
||||
assert self.config['api_version'] >= (0, 10) or self.config['api_version'] is None
|
||||
|
||||
if self._sasl_auth_future is None:
|
||||
# Build a SaslHandShakeRequest message
|
||||
request = SaslHandShakeRequest[0](self.config['sasl_mechanism'])
|
||||
future = Future()
|
||||
sasl_response = self._send(request)
|
||||
sasl_response.add_callback(self._handle_sasl_handshake_response, future)
|
||||
sasl_response.add_errback(lambda f, e: f.failure(e), future)
|
||||
self._sasl_auth_future = future
|
||||
self._recv()
|
||||
if self._sasl_auth_future.failed():
|
||||
raise self._sasl_auth_future.exception
|
||||
return self._sasl_auth_future.succeeded()
|
||||
|
||||
def _handle_sasl_handshake_response(self, future, response):
|
||||
error_type = Errors.for_code(response.error_code)
|
||||
if error_type is not Errors.NoError:
|
||||
error = error_type(self)
|
||||
self.close(error=error)
|
||||
return future.failure(error_type(self))
|
||||
|
||||
if self.config['sasl_mechanism'] == 'PLAIN':
|
||||
return self._try_authenticate_plain(future)
|
||||
else:
|
||||
return future.failure(
|
||||
Errors.UnsupportedSaslMechanismError(
|
||||
'kafka-python does not support SASL mechanism %s' %
|
||||
self.config['sasl_mechanism']))
|
||||
|
||||
def _try_authenticate_plain(self, future):
|
||||
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
|
||||
log.warning('%s: Sending username and password in the clear', str(self))
|
||||
|
||||
# Build a SaslHandShakeRequest message
|
||||
correlation_id = self._next_correlation_id()
|
||||
request = SaslHandShakeRequest[0](self.config['sasl_mechanism'])
|
||||
header = RequestHeader(request,
|
||||
correlation_id=correlation_id,
|
||||
client_id=self.config['client_id'])
|
||||
|
||||
message = b''.join([header.encode(), request.encode()])
|
||||
size = Int32.encode(len(message))
|
||||
|
||||
# Attempt to send it over our socket
|
||||
try:
|
||||
self._sock.setblocking(True)
|
||||
self._sock.sendall(size + message)
|
||||
self._sock.setblocking(False)
|
||||
except (AssertionError, ConnectionError) as e:
|
||||
log.exception("Error sending %s to %s", request, self)
|
||||
error = Errors.ConnectionError("%s: %s" % (str(self), e))
|
||||
self.close(error=error)
|
||||
return False
|
||||
|
||||
future = Future()
|
||||
ifr = InFlightRequest(request=request,
|
||||
correlation_id=correlation_id,
|
||||
response_type=request.RESPONSE_TYPE,
|
||||
future=future,
|
||||
timestamp=time.time())
|
||||
self.in_flight_requests.append(ifr)
|
||||
|
||||
# Listen for a reply and check that the server supports the PLAIN mechanism
|
||||
response = None
|
||||
while not response:
|
||||
response = self.recv()
|
||||
|
||||
if not response.error_code is 0:
|
||||
raise Errors.for_code(response.error_code)
|
||||
|
||||
if not self.config['sasl_mechanism'] in response.enabled_mechanisms:
|
||||
raise Errors.AuthenticationMethodNotSupported(self.config['sasl_mechanism'] + " is not supported by broker")
|
||||
|
||||
return self._try_authenticate_plain()
|
||||
|
||||
def _try_authenticate_plain(self):
|
||||
data = b''
|
||||
try:
|
||||
self._sock.setblocking(True)
|
||||
# Send our credentials
|
||||
# Send PLAIN credentials per RFC-4616
|
||||
msg = bytes('\0'.join([self.config['sasl_plain_username'],
|
||||
self.config['sasl_plain_username'],
|
||||
self.config['sasl_plain_password']]).encode('utf-8'))
|
||||
@@ -351,26 +347,26 @@ class BrokerConnection(object):
|
||||
# The connection is closed on failure
|
||||
received_bytes = 0
|
||||
while received_bytes < 4:
|
||||
data = data + self._sock.recv(4 - received_bytes)
|
||||
received_bytes = received_bytes + len(data)
|
||||
data += self._sock.recv(4 - received_bytes)
|
||||
received_bytes += len(data)
|
||||
if not data:
|
||||
log.error('%s: Authentication failed for user %s', self, self.config['sasl_plain_username'])
|
||||
self.close(error=Errors.ConnectionError('Authentication failed'))
|
||||
raise Errors.AuthenticationFailedError('Authentication failed for user {}'.format(self.config['sasl_plain_username']))
|
||||
error = Errors.AuthenticationFailedError(
|
||||
'Authentication failed for user {0}'.format(
|
||||
self.config['sasl_plain_username']))
|
||||
future.failure(error)
|
||||
raise error
|
||||
self._sock.setblocking(False)
|
||||
except (AssertionError, ConnectionError) as e:
|
||||
log.exception("%s: Error receiving reply from server", self)
|
||||
error = Errors.ConnectionError("%s: %s" % (str(self), e))
|
||||
future.failure(error)
|
||||
self.close(error=error)
|
||||
return False
|
||||
|
||||
with io.BytesIO() as buffer:
|
||||
buffer.write(data)
|
||||
buffer.seek(0)
|
||||
if not Int32.decode(buffer) == 0:
|
||||
raise Errors.KafkaError('Expected a zero sized reply after sending credentials')
|
||||
if data != '\x00\x00\x00\x00':
|
||||
return future.failure(Errors.AuthenticationFailedError())
|
||||
|
||||
return True
|
||||
return future.success(True)
|
||||
|
||||
def blacked_out(self):
|
||||
"""
|
||||
@@ -437,6 +433,10 @@ class BrokerConnection(object):
|
||||
return future.failure(Errors.ConnectionError(str(self)))
|
||||
elif not self.can_send_more():
|
||||
return future.failure(Errors.TooManyInFlightRequests(str(self)))
|
||||
return self._send(request, expect_response=expect_response)
|
||||
|
||||
def _send(self, request, expect_response=True):
|
||||
future = Future()
|
||||
correlation_id = self._next_correlation_id()
|
||||
header = RequestHeader(request,
|
||||
correlation_id=correlation_id,
|
||||
@@ -505,6 +505,9 @@ class BrokerConnection(object):
|
||||
self.config['request_timeout_ms']))
|
||||
return None
|
||||
|
||||
return self._recv()
|
||||
|
||||
def _recv(self):
|
||||
# Not receiving is the state of reading the payload header
|
||||
if not self._receiving:
|
||||
try:
|
||||
@@ -552,7 +555,7 @@ class BrokerConnection(object):
|
||||
# enough data to read the full bytes_to_read
|
||||
# but if the socket is disconnected, we will get empty data
|
||||
# without an exception raised
|
||||
if not data:
|
||||
if bytes_to_read and not data:
|
||||
log.error('%s: socket disconnected', self)
|
||||
self.close(error=Errors.ConnectionError('socket disconnected'))
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user