Add ssl support to BrokerConnection

This commit is contained in:
Dana Powers
2016-04-04 19:50:45 -07:00
parent cda2d59da4
commit ffd1423a81

View File

@@ -5,6 +5,7 @@ import logging
import io import io
from random import shuffle from random import shuffle
import socket import socket
import ssl
import struct import struct
from threading import local from threading import local
import time import time
@@ -29,11 +30,25 @@ log = logging.getLogger(__name__)
DEFAULT_SOCKET_TIMEOUT_SECONDS = 120 DEFAULT_SOCKET_TIMEOUT_SECONDS = 120
DEFAULT_KAFKA_PORT = 9092 DEFAULT_KAFKA_PORT = 9092
# support older ssl libraries
try:
assert ssl.SSLWantReadError
assert ssl.SSLWantWriteError
assert ssl.SSLZeroReturnError
except:
log.warning('old ssl module detected.'
' ssl error handling may not operate cleanly.'
' Consider upgrading to python 3.5 or 2.7')
ssl.SSLWantReadError = ssl.SSLError
ssl.SSLWantWriteError = ssl.SSLError
ssl.SSLZeroReturnError = ssl.SSLError
class ConnectionStates(object): class ConnectionStates(object):
DISCONNECTING = '<disconnecting>' DISCONNECTING = '<disconnecting>'
DISCONNECTED = '<disconnected>' DISCONNECTED = '<disconnected>'
CONNECTING = '<connecting>' CONNECTING = '<connecting>'
HANDSHAKE = '<handshake>'
CONNECTED = '<connected>' CONNECTED = '<connected>'
@@ -49,6 +64,12 @@ class BrokerConnection(object):
'max_in_flight_requests_per_connection': 5, 'max_in_flight_requests_per_connection': 5,
'receive_buffer_bytes': None, 'receive_buffer_bytes': None,
'send_buffer_bytes': None, 'send_buffer_bytes': None,
'security_protocol': 'PLAINTEXT',
'ssl_context': None,
'ssl_check_hostname': True,
'ssl_cafile': None,
'ssl_certfile': None,
'ssl_keyfile': None,
'api_version': (0, 8, 2), # default to most restrictive 'api_version': (0, 8, 2), # default to most restrictive
'state_change_callback': lambda conn: True, 'state_change_callback': lambda conn: True,
} }
@@ -66,6 +87,9 @@ class BrokerConnection(object):
self.state = ConnectionStates.DISCONNECTED self.state = ConnectionStates.DISCONNECTED
self._sock = None self._sock = None
self._ssl_context = None
if self.config['ssl_context'] is not None:
self._ssl_context = self.config['ssl_context']
self._rbuffer = io.BytesIO() self._rbuffer = io.BytesIO()
self._receiving = False self._receiving = False
self._next_payload_bytes = 0 self._next_payload_bytes = 0
@@ -87,6 +111,8 @@ class BrokerConnection(object):
self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF,
self.config['send_buffer_bytes']) self.config['send_buffer_bytes'])
self._sock.setblocking(False) self._sock.setblocking(False)
if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
self._wrap_ssl()
self.state = ConnectionStates.CONNECTING self.state = ConnectionStates.CONNECTING
self.last_attempt = time.time() self.last_attempt = time.time()
self.config['state_change_callback'](self) self.config['state_change_callback'](self)
@@ -103,7 +129,11 @@ class BrokerConnection(object):
# Connection succeeded # Connection succeeded
if not ret or ret == errno.EISCONN: if not ret or ret == errno.EISCONN:
log.debug('%s: established TCP connection', str(self)) log.debug('%s: established TCP connection', str(self))
self.state = ConnectionStates.CONNECTED if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
log.debug('%s: initiating SSL handshake', str(self))
self.state = ConnectionStates.HANDSHAKE
else:
self.state = ConnectionStates.CONNECTED
self.config['state_change_callback'](self) self.config['state_change_callback'](self)
# Connection failed # Connection failed
@@ -122,8 +152,60 @@ class BrokerConnection(object):
else: else:
pass pass
if self.state is ConnectionStates.HANDSHAKE:
if self._try_handshake():
log.debug('%s: completed SSL handshake.', str(self))
self.state = ConnectionStates.CONNECTED
self.config['state_change_callback'](self)
return self.state return self.state
def _wrap_ssl(self):
assert self.config['security_protocol'] in ('SSL', 'SASL_SSL')
if self._ssl_context is None:
log.debug('%s: configuring default SSL Context', str(self))
self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) # pylint: disable=no-member
self._ssl_context.options |= ssl.OP_NO_SSLv2 # pylint: disable=no-member
self._ssl_context.options |= ssl.OP_NO_SSLv3 # pylint: disable=no-member
self._ssl_context.verify_mode = ssl.CERT_OPTIONAL
if self.config['ssl_check_hostname']:
self._ssl_context.check_hostname = True
if self.config['ssl_cafile']:
log.info('%s: Loading SSL CA from %s', str(self), self.config['ssl_cafile'])
self._ssl_context.load_verify_locations(self.config['ssl_cafile'])
self._ssl_context.verify_mode = ssl.CERT_REQUIRED
if self.config['ssl_certfile'] and self.config['ssl_keyfile']:
log.info('%s: Loading SSL Cert from %s', str(self), self.config['ssl_certfile'])
log.info('%s: Loading SSL Key from %s', str(self), self.config['ssl_keyfile'])
self._ssl_context.load_cert_chain(
certfile=self.config['ssl_certfile'],
keyfile=self.config['ssl_keyfile'])
log.debug('%s: wrapping socket in ssl context', str(self))
try:
self._sock = self._ssl_context.wrap_socket(
self._sock,
server_hostname=self.host,
do_handshake_on_connect=False)
except ssl.SSLError:
log.exception('%s: Failed to wrap socket in SSLContext!', str(self))
self.close()
self.last_failure = time.time()
def _try_handshake(self):
assert self.config['security_protocol'] in ('SSL', 'SASL_SSL')
try:
self._sock.do_handshake()
return True
# old ssl in python2.6 will swallow all SSLErrors here...
except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
pass
except ssl.SSLZeroReturnError:
log.warning('SSL connection closed by server during handshake.')
self.close()
# Other SSLErrors will be raised to user
return False
def blacked_out(self): def blacked_out(self):
""" """
Return true if we are disconnected from the given node and can't Return true if we are disconnected from the given node and can't
@@ -140,8 +222,10 @@ class BrokerConnection(object):
return self.state is ConnectionStates.CONNECTED return self.state is ConnectionStates.CONNECTED
def connecting(self): def connecting(self):
"""Return True iff socket is in intermediate connecting state.""" """Returns True if still connecting (this may encompass several
return self.state is ConnectionStates.CONNECTING different states, such as SSL handshake, authorization, etc)."""
return self.state in (ConnectionStates.CONNECTING,
ConnectionStates.HANDSHAKE)
def disconnected(self): def disconnected(self):
"""Return True iff socket is closed""" """Return True iff socket is closed"""
@@ -260,6 +344,8 @@ class BrokerConnection(object):
# An extremely small, but non-zero, probability that there are # An extremely small, but non-zero, probability that there are
# more than 0 but not yet 4 bytes available to read # more than 0 but not yet 4 bytes available to read
self._rbuffer.write(self._sock.recv(4 - self._rbuffer.tell())) self._rbuffer.write(self._sock.recv(4 - self._rbuffer.tell()))
except ssl.SSLWantReadError:
return None
except ConnectionError as e: except ConnectionError as e:
if six.PY2 and e.errno == errno.EWOULDBLOCK: if six.PY2 and e.errno == errno.EWOULDBLOCK:
return None return None
@@ -286,6 +372,8 @@ class BrokerConnection(object):
staged_bytes = self._rbuffer.tell() staged_bytes = self._rbuffer.tell()
try: try:
self._rbuffer.write(self._sock.recv(self._next_payload_bytes - staged_bytes)) self._rbuffer.write(self._sock.recv(self._next_payload_bytes - staged_bytes))
except ssl.SSLWantReadError:
return None
except ConnectionError as e: except ConnectionError as e:
# Extremely small chance that we have exactly 4 bytes for a # Extremely small chance that we have exactly 4 bytes for a
# header, but nothing to read in the body yet # header, but nothing to read in the body yet