Switch BrokerConnection to (mostly) non-blocking IO.

- return kafka.Future on send()
- recv is now a non-blocking call that completes futures when possible
- update KafkaClient to block on future completion
This commit is contained in:
Dana Powers
2015-12-17 17:29:54 -08:00
parent 799824535c
commit f1ad0247df
6 changed files with 389 additions and 157 deletions

View File

@@ -3,7 +3,6 @@ import copy
import functools import functools
import logging import logging
import random import random
import select
import time import time
import six import six
@@ -15,7 +14,9 @@ from kafka.common import (TopicAndPartition, BrokerMetadata, UnknownError,
LeaderNotAvailableError, UnknownTopicOrPartitionError, LeaderNotAvailableError, UnknownTopicOrPartitionError,
NotLeaderForPartitionError, ReplicaNotAvailableError) NotLeaderForPartitionError, ReplicaNotAvailableError)
from kafka.conn import collect_hosts, BrokerConnection, DEFAULT_SOCKET_TIMEOUT_SECONDS from kafka.conn import (
collect_hosts, BrokerConnection, DEFAULT_SOCKET_TIMEOUT_SECONDS,
ConnectionStates)
from kafka.protocol import KafkaProtocol from kafka.protocol import KafkaProtocol
@@ -45,7 +46,6 @@ class KafkaClient(object):
self.load_metadata_for_topics() # bootstrap with all metadata self.load_metadata_for_topics() # bootstrap with all metadata
################## ##################
# Private API # # Private API #
################## ##################
@@ -56,11 +56,14 @@ class KafkaClient(object):
if host_key not in self._conns: if host_key not in self._conns:
self._conns[host_key] = BrokerConnection( self._conns[host_key] = BrokerConnection(
host, port, host, port,
timeout=self.timeout, request_timeout_ms=self.timeout * 1000,
client_id=self.client_id client_id=self.client_id
) )
return self._conns[host_key] conn = self._conns[host_key]
while conn.connect() == ConnectionStates.CONNECTING:
pass
return conn
def _get_leader_for_partition(self, topic, partition): def _get_leader_for_partition(self, topic, partition):
""" """
@@ -137,16 +140,23 @@ class KafkaClient(object):
for (host, port) in hosts: for (host, port) in hosts:
conn = self._get_conn(host, port) conn = self._get_conn(host, port)
request = encoder_fn(payloads=payloads) if not conn.connected():
correlation_id = conn.send(request) log.warning("Skipping unconnected connection: %s", conn)
if correlation_id is None:
continue continue
response = conn.recv() request = encoder_fn(payloads=payloads)
if response is not None: future = conn.send(request)
decoded = decoder_fn(response)
return decoded
raise KafkaUnavailableError('All servers failed to process request') # Block
while not future.is_done:
conn.recv()
if future.failed():
log.error("Request failed: %s", future.exception)
continue
return decoder_fn(future.value)
raise KafkaUnavailableError('All servers failed to process request: %s' % hosts)
def _payloads_by_broker(self, payloads): def _payloads_by_broker(self, payloads):
payloads_by_broker = collections.defaultdict(list) payloads_by_broker = collections.defaultdict(list)
@@ -204,55 +214,59 @@ class KafkaClient(object):
# For each BrokerConnection keep the real socket so that we can use # For each BrokerConnection keep the real socket so that we can use
# a select to perform unblocking I/O # a select to perform unblocking I/O
connections_by_socket = {} connections_by_future = {}
for broker, broker_payloads in six.iteritems(payloads_by_broker): for broker, broker_payloads in six.iteritems(payloads_by_broker):
if broker is None: if broker is None:
failed_payloads(broker_payloads) failed_payloads(broker_payloads)
continue continue
conn = self._get_conn(broker.host, broker.port) conn = self._get_conn(broker.host, broker.port)
conn.connect()
if not conn.connected():
refresh_metadata = True
failed_payloads(broker_payloads)
continue
request = encoder_fn(payloads=broker_payloads) request = encoder_fn(payloads=broker_payloads)
# decoder_fn=None signal that the server is expected to not # decoder_fn=None signal that the server is expected to not
# send a response. This probably only applies to # send a response. This probably only applies to
# ProduceRequest w/ acks = 0 # ProduceRequest w/ acks = 0
expect_response = (decoder_fn is not None) expect_response = (decoder_fn is not None)
correlation_id = conn.send(request, expect_response=expect_response) future = conn.send(request, expect_response=expect_response)
if correlation_id is None: if future.failed():
refresh_metadata = True refresh_metadata = True
failed_payloads(broker_payloads) failed_payloads(broker_payloads)
log.warning('Error attempting to send request %s '
'to server %s', correlation_id, broker)
continue continue
if not expect_response: if not expect_response:
log.debug('Request %s does not expect a response '
'(skipping conn.recv)', correlation_id)
for payload in broker_payloads: for payload in broker_payloads:
topic_partition = (str(payload.topic), payload.partition) topic_partition = (str(payload.topic), payload.partition)
responses[topic_partition] = None responses[topic_partition] = None
continue continue
connections_by_socket[conn._read_fd] = (conn, broker) connections_by_future[future] = (conn, broker)
conn = None conn = None
while connections_by_socket: while connections_by_future:
sockets = connections_by_socket.keys() futures = list(connections_by_future.keys())
rlist, _, _ = select.select(sockets, [], [], None) for future in futures:
conn, broker = connections_by_socket.pop(rlist[0])
correlation_id = conn.next_correlation_id_recv()
response = conn.recv()
if response is None:
refresh_metadata = True
failed_payloads(payloads_by_broker[broker])
log.warning('Error receiving response to request %s '
'from server %s', correlation_id, broker)
continue
for payload_response in decoder_fn(response): if not future.is_done:
topic_partition = (str(payload_response.topic), conn, _ = connections_by_future[future]
payload_response.partition) conn.recv()
responses[topic_partition] = payload_response continue
_, broker = connections_by_future.pop(future)
if future.failed():
refresh_metadata = True
failed_payloads(payloads_by_broker[broker])
else:
for payload_response in decoder_fn(future.value):
topic_partition = (str(payload_response.topic),
payload_response.partition)
responses[topic_partition] = payload_response
if refresh_metadata: if refresh_metadata:
self.reset_all_metadata() self.reset_all_metadata()
@@ -392,7 +406,9 @@ class KafkaClient(object):
def reinit(self): def reinit(self):
for conn in self._conns.values(): for conn in self._conns.values():
conn.reinit() conn.close()
while conn.connect() == ConnectionStates.CONNECTING:
pass
def reset_topic_metadata(self, *topics): def reset_topic_metadata(self, *topics):
for topic in topics: for topic in topics:

View File

@@ -73,7 +73,7 @@ class Cluster(object):
def _bootstrap(self, hosts, timeout=2): def _bootstrap(self, hosts, timeout=2):
for host, port in hosts: for host, port in hosts:
conn = BrokerConnection(host, port, timeout) conn = BrokerConnection(host, port)
if not conn.connect(): if not conn.connect():
continue continue
self._brokers['bootstrap'] = conn self._brokers['bootstrap'] = conn

View File

@@ -93,6 +93,22 @@ class KafkaError(RuntimeError):
pass pass
class IllegalStateError(KafkaError):
pass
class RetriableError(KafkaError):
pass
class DisconnectError(KafkaError):
pass
class CorrelationIdError(KafkaError):
pass
class BrokerResponseError(KafkaError): class BrokerResponseError(KafkaError):
errno = None errno = None
message = None message = None

View File

@@ -1,15 +1,20 @@
from collections import deque import collections
import copy import copy
import errno
import logging import logging
import io
from random import shuffle from random import shuffle
from select import select from select import select
import socket import socket
import struct import struct
from threading import local from threading import local
import time
import six import six
import kafka.common as Errors
from kafka.common import ConnectionError from kafka.common import ConnectionError
from kafka.future import Future
from kafka.protocol.api import RequestHeader from kafka.protocol.api import RequestHeader
from kafka.protocol.types import Int32 from kafka.protocol.types import Int32
@@ -20,106 +25,244 @@ DEFAULT_SOCKET_TIMEOUT_SECONDS = 120
DEFAULT_KAFKA_PORT = 9092 DEFAULT_KAFKA_PORT = 9092
class BrokerConnection(local): class ConnectionStates(object):
def __init__(self, host, port, timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS, DISCONNECTED = 1
client_id='kafka-python-0.10.0', correlation_id=0): CONNECTING = 2
super(BrokerConnection, self).__init__() CONNECTED = 3
InFlightRequest = collections.namedtuple('InFlightRequest',
['request', 'response_type', 'correlation_id', 'future', 'timestamp'])
class BrokerConnection(object):
_receive_buffer_bytes = 32768
_send_buffer_bytes = 32768
_client_id = 'kafka-python-0.10.0'
_correlation_id = 0
_request_timeout_ms = 40000
def __init__(self, host, port, **kwargs):
self.host = host self.host = host
self.port = port self.port = port
self.timeout = timeout self.in_flight_requests = collections.deque()
self._write_fd = None
self._read_fd = None for config in ('receive_buffer_bytes', 'send_buffer_bytes',
self.correlation_id = correlation_id 'client_id', 'correlation_id', 'request_timeout_ms'):
self.client_id = client_id if config in kwargs:
self.in_flight_requests = deque() setattr(self, '_' + config, kwargs.pop(config))
self.state = ConnectionStates.DISCONNECTED
self._sock = None
self._rbuffer = io.BytesIO()
self._receiving = False
self._next_payload_bytes = 0
self._last_connection_attempt = None
self._last_connection_failure = None
def connect(self): def connect(self):
if self.connected(): """Attempt to connect and return ConnectionState"""
if self.state is ConnectionStates.DISCONNECTED:
self.close() self.close()
try: self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock = socket.create_connection((self.host, self.port), self.timeout) self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self._receive_buffer_bytes)
self._write_fd = sock.makefile('wb') self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self._send_buffer_bytes)
self._read_fd = sock.makefile('rb') self._sock.setblocking(False)
except socket.error: ret = self._sock.connect_ex((self.host, self.port))
log.exception("Error in BrokerConnection.connect()") self._last_connection_attempt = time.time()
return None
self.in_flight_requests.clear() if not ret or ret is errno.EISCONN:
return True self.state = ConnectionStates.CONNECTED
elif ret in (errno.EINPROGRESS, errno.EALREADY):
self.state = ConnectionStates.CONNECTING
else:
log.error('Connect attempt returned error %s. Disconnecting.', ret)
self.close()
self._last_connection_failure = time.time()
if self.state is ConnectionStates.CONNECTING:
# in non-blocking mode, use repeated calls to socket.connect_ex
# to check connection status
if time.time() > (self._request_timeout_ms / 1000.0) + self._last_connection_attempt:
log.error('Connection attempt timed out')
self.close() # error=TimeoutError ?
self._last_connection_failure = time.time()
ret = self._sock.connect_ex((self.host, self.port))
if not ret or ret is errno.EISCONN:
self.state = ConnectionStates.CONNECTED
elif ret is not errno.EALREADY:
log.error('Connect attempt returned error %s. Disconnecting.', ret)
self.close()
self._last_connection_failure = time.time()
return self.state
def connected(self): def connected(self):
return (self._read_fd is not None and self._write_fd is not None) return self.state is ConnectionStates.CONNECTED
def close(self): def close(self, error=None):
if self.connected(): if self._sock:
try: self._sock.close()
self._read_fd.close() self._sock = None
self._write_fd.close() self.state = ConnectionStates.DISCONNECTED
except socket.error:
log.exception("Error in BrokerConnection.close()") if error is None:
pass error = Errors.DisconnectError()
self._read_fd = None while self.in_flight_requests:
self._write_fd = None ifr = self.in_flight_requests.popleft()
ifr.future.failure(error)
self.in_flight_requests.clear() self.in_flight_requests.clear()
self._receiving = False
self._next_payload_bytes = 0
self._rbuffer.seek(0)
self._rbuffer.truncate()
def send(self, request, expect_response=True): def send(self, request, expect_response=True):
if not self.connected() and not self.connect(): """send request, return Future()
return None
self.correlation_id += 1 Can block on network if request is larger than send_buffer_bytes
"""
future = Future()
if not self.connected():
return future.failure(Errors.DisconnectError())
self._correlation_id += 1
header = RequestHeader(request, header = RequestHeader(request,
correlation_id=self.correlation_id, correlation_id=self._correlation_id,
client_id=self.client_id) client_id=self._client_id)
message = b''.join([header.encode(), request.encode()]) message = b''.join([header.encode(), request.encode()])
size = Int32.encode(len(message)) size = Int32.encode(len(message))
try: try:
self._write_fd.write(size) # In the future we might manage an internal write buffer
self._write_fd.write(message) # and send bytes asynchronously. For now, just block
self._write_fd.flush() # sending each request payload
except socket.error: self._sock.setblocking(True)
log.exception("Error in BrokerConnection.send(): %s", request) sent_bytes = self._sock.send(size)
self.close() assert sent_bytes == len(size)
return None sent_bytes = self._sock.send(message)
if expect_response: assert sent_bytes == len(message)
self.in_flight_requests.append((self.correlation_id, request.RESPONSE_TYPE)) self._sock.setblocking(False)
log.debug('Request %d: %s', self.correlation_id, request) except (AssertionError, socket.error) as e:
return self.correlation_id log.debug("Error in BrokerConnection.send(): %s", request)
self.close(error=e)
return future.failure(e)
log.debug('Request %d: %s', self._correlation_id, request)
def recv(self, timeout=None): if expect_response:
ifr = InFlightRequest(request=request,
correlation_id=self._correlation_id,
response_type=request.RESPONSE_TYPE,
future=future,
timestamp=time.time())
self.in_flight_requests.append(ifr)
else:
future.success(None)
return future
def recv(self, timeout=0):
"""Non-blocking network receive
Return response if available
"""
if not self.connected(): if not self.connected():
log.warning('Cannot recv: socket not connected')
# If requests are pending, we should close the socket and
# fail all the pending request futures
if self.in_flight_requests:
self.close()
return None return None
readable, _, _ = select([self._read_fd], [], [], timeout)
if not readable:
return None
if not self.in_flight_requests: if not self.in_flight_requests:
log.warning('No in-flight-requests to recv') log.warning('No in-flight-requests to recv')
return None return None
correlation_id, response_type = self.in_flight_requests.popleft()
# Current implementation does not use size self._fail_timed_out_requests()
# instead we read directly from the socket fd buffer
# alternatively, we could read size bytes into a separate buffer readable, _, _ = select([self._sock], [], [], timeout)
# and decode from that buffer (and verify buffer is empty afterwards) if not readable:
try: return None
size = Int32.decode(self._read_fd)
recv_correlation_id = Int32.decode(self._read_fd) # Not receiving is the state of reading the payload header
if correlation_id != recv_correlation_id: if not self._receiving:
raise RuntimeError('Correlation ids do not match!') try:
response = response_type.decode(self._read_fd) # An extremely small, but non-zero, probability that there are
except (RuntimeError, socket.error, struct.error): # more than 0 but not yet 4 bytes available to read
log.exception("Error in BrokerConnection.recv() for request %d", correlation_id) self._rbuffer.write(self._sock.recv(4 - self._rbuffer.tell()))
except socket.error as e:
if e.errno == errno.EWOULDBLOCK:
# This shouldn't happen after selecting above
# but just in case
return None
log.exception("Error receiving 4-byte payload header - closing socket")
self.close(error=e)
return None
if self._rbuffer.tell() == 4:
self._rbuffer.seek(0)
self._next_payload_bytes = Int32.decode(self._rbuffer)
# reset buffer and switch state to receiving payload bytes
self._rbuffer.seek(0)
self._rbuffer.truncate()
self._receiving = True
elif self._rbuffer.tell() > 4:
raise Errors.KafkaError('this should not happen - are you threading?')
if self._receiving:
staged_bytes = self._rbuffer.tell()
try:
self._rbuffer.write(self._sock.recv(self._next_payload_bytes - staged_bytes))
except socket.error as e:
# Extremely small chance that we have exactly 4 bytes for a
# header, but nothing to read in the body yet
if e.errno == errno.EWOULDBLOCK:
return None
log.exception()
self.close(error=e)
return None
staged_bytes = self._rbuffer.tell()
if staged_bytes > self._next_payload_bytes:
self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
if staged_bytes != self._next_payload_bytes:
return None
self._receiving = False
self._next_payload_bytes = 0
self._rbuffer.seek(0)
response = self._process_response(self._rbuffer)
self._rbuffer.seek(0)
self._rbuffer.truncate()
return response
def _process_response(self, read_buffer):
ifr = self.in_flight_requests.popleft()
# verify send/recv correlation ids match
recv_correlation_id = Int32.decode(read_buffer)
if ifr.correlation_id != recv_correlation_id:
error = Errors.CorrelationIdError(
'Correlation ids do not match: sent %d, recv %d'
% (ifr.correlation_id, recv_correlation_id))
ifr.future.fail(error)
self.close() self.close()
return None return None
log.debug('Response %d: %s', correlation_id, response)
# decode response
response = ifr.response_type.decode(read_buffer)
ifr.future.success(response)
log.debug('Response %d: %s', ifr.correlation_id, response)
return response return response
def next_correlation_id_recv(self): def _fail_timed_out_requests(self):
if len(self.in_flight_requests) == 0: now = time.time()
return None while self.in_flight_requests:
return self.in_flight_requests[0][0] next_timeout = self.in_flight_requests[0].timestamp + (self._request_timeout_ms / 1000.0)
if now < next_timeout:
def next_correlation_id_send(self): break
return self.correlation_id + 1 timed_out = self.in_flight_requests.popleft()
error = Errors.RequestTimedOutError('Request timed out after %s ms' % self._request_timeout_ms)
def __getnewargs__(self): timed_out.future.failure(error)
return (self.host, self.port, self.timeout)
def __repr__(self): def __repr__(self):
return "<BrokerConnection host=%s port=%d>" % (self.host, self.port) return "<BrokerConnection host=%s port=%d>" % (self.host, self.port)
@@ -149,13 +292,7 @@ def collect_hosts(hosts, randomize=True):
class KafkaConnection(local): class KafkaConnection(local):
""" """A socket connection to a single Kafka broker
A socket connection to a single Kafka broker
This class is _not_ thread safe. Each call to `send` must be followed
by a call to `recv` in order to get the correct response. Eventually,
we can do something in here to facilitate multiplexed requests/responses
since the Kafka API includes a correlation id.
Arguments: Arguments:
host: the host name or IP address of a kafka broker host: the host name or IP address of a kafka broker

51
kafka/future.py Normal file
View File

@@ -0,0 +1,51 @@
from kafka.common import RetriableError, IllegalStateError
class Future(object):
def __init__(self):
self.is_done = False
self.value = None
self.exception = None
self._callbacks = []
self._errbacks = []
def succeeded(self):
return self.is_done and not self.exception
def failed(self):
return self.is_done and self.exception
def retriable(self):
return isinstance(self.exception, RetriableError)
def success(self, value):
if self.is_done:
raise IllegalStateError('Invalid attempt to complete a request future which is already complete')
self.value = value
self.is_done = True
for f in self._callbacks:
f(value)
return self
def failure(self, e):
if self.is_done:
raise IllegalStateError('Invalid attempt to complete a request future which is already complete')
self.exception = e
self.is_done = True
for f in self._errbacks:
f(e)
return self
def add_callback(self, f):
if self.is_done and not self.exception:
f(self.value)
else:
self._callbacks.append(f)
return self
def add_errback(self, f):
if self.is_done and self.exception:
f(self.exception)
else:
self._errbacks.append(f)
return self

View File

@@ -14,6 +14,7 @@ from kafka.common import (
KafkaTimeoutError, ConnectionError KafkaTimeoutError, ConnectionError
) )
from kafka.conn import KafkaConnection from kafka.conn import KafkaConnection
from kafka.future import Future
from kafka.protocol import KafkaProtocol, create_message from kafka.protocol import KafkaProtocol, create_message
from kafka.protocol.metadata import MetadataResponse from kafka.protocol.metadata import MetadataResponse
@@ -23,6 +24,17 @@ NO_ERROR = 0
UNKNOWN_TOPIC_OR_PARTITION = 3 UNKNOWN_TOPIC_OR_PARTITION = 3
NO_LEADER = 5 NO_LEADER = 5
def mock_conn(conn, success=True):
mocked = MagicMock()
mocked.connected.return_value = True
if success:
mocked.send.return_value = Future().success(True)
else:
mocked.send.return_value = Future().failure(Exception())
conn.return_value = mocked
class TestKafkaClient(unittest.TestCase): class TestKafkaClient(unittest.TestCase):
def test_init_with_list(self): def test_init_with_list(self):
with patch.object(KafkaClient, 'load_metadata_for_topics'): with patch.object(KafkaClient, 'load_metadata_for_topics'):
@@ -48,32 +60,30 @@ class TestKafkaClient(unittest.TestCase):
sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]), sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]),
sorted(client.hosts)) sorted(client.hosts))
def test_send_broker_unaware_request_fail(self): @patch.object(KafkaClient, '_get_conn')
@patch.object(KafkaClient, 'load_metadata_for_topics')
def test_send_broker_unaware_request_fail(self, load_metadata, conn):
mocked_conns = { mocked_conns = {
('kafka01', 9092): MagicMock(), ('kafka01', 9092): MagicMock(),
('kafka02', 9092): MagicMock() ('kafka02', 9092): MagicMock()
} }
for val in mocked_conns.values():
# inject KafkaConnection side effects mock_conn(val, success=False)
mocked_conns[('kafka01', 9092)].send.return_value = None
mocked_conns[('kafka02', 9092)].send.return_value = None
def mock_get_conn(host, port): def mock_get_conn(host, port):
return mocked_conns[(host, port)] return mocked_conns[(host, port)]
conn.side_effect = mock_get_conn
# patch to avoid making requests before we want it client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092'])
with patch.object(KafkaClient, 'load_metadata_for_topics'):
with patch.object(KafkaClient, '_get_conn', side_effect=mock_get_conn):
client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092'])
req = KafkaProtocol.encode_metadata_request() req = KafkaProtocol.encode_metadata_request()
with self.assertRaises(KafkaUnavailableError): with self.assertRaises(KafkaUnavailableError):
client._send_broker_unaware_request(payloads=['fake request'], client._send_broker_unaware_request(payloads=['fake request'],
encoder_fn=MagicMock(return_value='fake encoded message'), encoder_fn=MagicMock(return_value='fake encoded message'),
decoder_fn=lambda x: x) decoder_fn=lambda x: x)
for key, conn in six.iteritems(mocked_conns): for key, conn in six.iteritems(mocked_conns):
conn.send.assert_called_with('fake encoded message') conn.send.assert_called_with('fake encoded message')
def test_send_broker_unaware_request(self): def test_send_broker_unaware_request(self):
mocked_conns = { mocked_conns = {
@@ -82,9 +92,11 @@ class TestKafkaClient(unittest.TestCase):
('kafka03', 9092): MagicMock() ('kafka03', 9092): MagicMock()
} }
# inject KafkaConnection side effects # inject KafkaConnection side effects
mocked_conns[('kafka01', 9092)].send.return_value = None mock_conn(mocked_conns[('kafka01', 9092)], success=False)
mocked_conns[('kafka02', 9092)].recv.return_value = 'valid response' mock_conn(mocked_conns[('kafka03', 9092)], success=False)
mocked_conns[('kafka03', 9092)].send.return_value = None future = Future()
mocked_conns[('kafka02', 9092)].send.return_value = future
mocked_conns[('kafka02', 9092)].recv.side_effect = lambda: future.success('valid response')
def mock_get_conn(host, port): def mock_get_conn(host, port):
return mocked_conns[(host, port)] return mocked_conns[(host, port)]
@@ -101,11 +113,11 @@ class TestKafkaClient(unittest.TestCase):
self.assertEqual('valid response', resp) self.assertEqual('valid response', resp)
mocked_conns[('kafka02', 9092)].recv.assert_called_once_with() mocked_conns[('kafka02', 9092)].recv.assert_called_once_with()
@patch('kafka.client.BrokerConnection') @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol') @patch('kafka.client.KafkaProtocol')
def test_load_metadata(self, protocol, conn): def test_load_metadata(self, protocol, conn):
conn.recv.return_value = 'response' # anything but None mock_conn(conn)
brokers = [ brokers = [
BrokerMetadata(0, 'broker_1', 4567), BrokerMetadata(0, 'broker_1', 4567),
@@ -151,11 +163,11 @@ class TestKafkaClient(unittest.TestCase):
# This should not raise # This should not raise
client.load_metadata_for_topics('topic_no_leader') client.load_metadata_for_topics('topic_no_leader')
@patch('kafka.client.BrokerConnection') @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol') @patch('kafka.client.KafkaProtocol')
def test_has_metadata_for_topic(self, protocol, conn): def test_has_metadata_for_topic(self, protocol, conn):
conn.recv.return_value = 'response' # anything but None mock_conn(conn)
brokers = [ brokers = [
BrokerMetadata(0, 'broker_1', 4567), BrokerMetadata(0, 'broker_1', 4567),
@@ -181,11 +193,11 @@ class TestKafkaClient(unittest.TestCase):
# Topic with partition metadata, but no leaders return True # Topic with partition metadata, but no leaders return True
self.assertTrue(client.has_metadata_for_topic('topic_noleaders')) self.assertTrue(client.has_metadata_for_topic('topic_noleaders'))
@patch('kafka.client.BrokerConnection') @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol.decode_metadata_response') @patch('kafka.client.KafkaProtocol.decode_metadata_response')
def test_ensure_topic_exists(self, decode_metadata_response, conn): def test_ensure_topic_exists(self, decode_metadata_response, conn):
conn.recv.return_value = 'response' # anything but None mock_conn(conn)
brokers = [ brokers = [
BrokerMetadata(0, 'broker_1', 4567), BrokerMetadata(0, 'broker_1', 4567),
@@ -213,12 +225,12 @@ class TestKafkaClient(unittest.TestCase):
# This should not raise # This should not raise
client.ensure_topic_exists('topic_noleaders', timeout=1) client.ensure_topic_exists('topic_noleaders', timeout=1)
@patch('kafka.client.BrokerConnection') @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol') @patch('kafka.client.KafkaProtocol')
def test_get_leader_for_partitions_reloads_metadata(self, protocol, conn): def test_get_leader_for_partitions_reloads_metadata(self, protocol, conn):
"Get leader for partitions reload metadata if it is not available" "Get leader for partitions reload metadata if it is not available"
conn.recv.return_value = 'response' # anything but None mock_conn(conn)
brokers = [ brokers = [
BrokerMetadata(0, 'broker_1', 4567), BrokerMetadata(0, 'broker_1', 4567),
@@ -251,11 +263,11 @@ class TestKafkaClient(unittest.TestCase):
TopicAndPartition('topic_one_partition', 0): brokers[0]}, TopicAndPartition('topic_one_partition', 0): brokers[0]},
client.topics_to_brokers) client.topics_to_brokers)
@patch('kafka.client.BrokerConnection') @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol') @patch('kafka.client.KafkaProtocol')
def test_get_leader_for_unassigned_partitions(self, protocol, conn): def test_get_leader_for_unassigned_partitions(self, protocol, conn):
conn.recv.return_value = 'response' # anything but None mock_conn(conn)
brokers = [ brokers = [
BrokerMetadata(0, 'broker_1', 4567), BrokerMetadata(0, 'broker_1', 4567),
@@ -278,11 +290,11 @@ class TestKafkaClient(unittest.TestCase):
with self.assertRaises(UnknownTopicOrPartitionError): with self.assertRaises(UnknownTopicOrPartitionError):
client._get_leader_for_partition('topic_unknown', 0) client._get_leader_for_partition('topic_unknown', 0)
@patch('kafka.client.BrokerConnection') @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol') @patch('kafka.client.KafkaProtocol')
def test_get_leader_exceptions_when_noleader(self, protocol, conn): def test_get_leader_exceptions_when_noleader(self, protocol, conn):
conn.recv.return_value = 'response' # anything but None mock_conn(conn)
brokers = [ brokers = [
BrokerMetadata(0, 'broker_1', 4567), BrokerMetadata(0, 'broker_1', 4567),
@@ -325,10 +337,10 @@ class TestKafkaClient(unittest.TestCase):
self.assertEqual(brokers[0], client._get_leader_for_partition('topic_noleader', 0)) self.assertEqual(brokers[0], client._get_leader_for_partition('topic_noleader', 0))
self.assertEqual(brokers[1], client._get_leader_for_partition('topic_noleader', 1)) self.assertEqual(brokers[1], client._get_leader_for_partition('topic_noleader', 1))
@patch('kafka.client.BrokerConnection') @patch.object(KafkaClient, '_get_conn')
@patch('kafka.client.KafkaProtocol') @patch('kafka.client.KafkaProtocol')
def test_send_produce_request_raises_when_noleader(self, protocol, conn): def test_send_produce_request_raises_when_noleader(self, protocol, conn):
conn.recv.return_value = 'response' # anything but None mock_conn(conn)
brokers = [ brokers = [
BrokerMetadata(0, 'broker_1', 4567), BrokerMetadata(0, 'broker_1', 4567),
@@ -352,11 +364,11 @@ class TestKafkaClient(unittest.TestCase):
with self.assertRaises(LeaderNotAvailableError): with self.assertRaises(LeaderNotAvailableError):
client.send_produce_request(requests) client.send_produce_request(requests)
@patch('kafka.client.BrokerConnection') @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol') @patch('kafka.client.KafkaProtocol')
def test_send_produce_request_raises_when_topic_unknown(self, protocol, conn): def test_send_produce_request_raises_when_topic_unknown(self, protocol, conn):
conn.recv.return_value = 'response' # anything but None mock_conn(conn)
brokers = [ brokers = [
BrokerMetadata(0, 'broker_1', 4567), BrokerMetadata(0, 'broker_1', 4567),