Switch BrokerConnection to (mostly) non-blocking IO.
- return kafka.Future on send()
- recv is now a non-blocking call that completes futures when possible
- update KafkaClient to block on future completion
This commit is contained in:
@@ -3,7 +3,6 @@ import copy
|
||||
import functools
|
||||
import logging
|
||||
import random
|
||||
import select
|
||||
import time
|
||||
|
||||
import six
|
||||
@@ -15,7 +14,9 @@ from kafka.common import (TopicAndPartition, BrokerMetadata, UnknownError,
|
||||
LeaderNotAvailableError, UnknownTopicOrPartitionError,
|
||||
NotLeaderForPartitionError, ReplicaNotAvailableError)
|
||||
|
||||
from kafka.conn import collect_hosts, BrokerConnection, DEFAULT_SOCKET_TIMEOUT_SECONDS
|
||||
from kafka.conn import (
|
||||
collect_hosts, BrokerConnection, DEFAULT_SOCKET_TIMEOUT_SECONDS,
|
||||
ConnectionStates)
|
||||
from kafka.protocol import KafkaProtocol
|
||||
|
||||
|
||||
@@ -45,7 +46,6 @@ class KafkaClient(object):
|
||||
|
||||
self.load_metadata_for_topics() # bootstrap with all metadata
|
||||
|
||||
|
||||
##################
|
||||
# Private API #
|
||||
##################
|
||||
@@ -56,11 +56,14 @@ class KafkaClient(object):
|
||||
if host_key not in self._conns:
|
||||
self._conns[host_key] = BrokerConnection(
|
||||
host, port,
|
||||
timeout=self.timeout,
|
||||
request_timeout_ms=self.timeout * 1000,
|
||||
client_id=self.client_id
|
||||
)
|
||||
|
||||
return self._conns[host_key]
|
||||
conn = self._conns[host_key]
|
||||
while conn.connect() == ConnectionStates.CONNECTING:
|
||||
pass
|
||||
return conn
|
||||
|
||||
def _get_leader_for_partition(self, topic, partition):
|
||||
"""
|
||||
@@ -137,16 +140,23 @@ class KafkaClient(object):
|
||||
|
||||
for (host, port) in hosts:
|
||||
conn = self._get_conn(host, port)
|
||||
request = encoder_fn(payloads=payloads)
|
||||
correlation_id = conn.send(request)
|
||||
if correlation_id is None:
|
||||
if not conn.connected():
|
||||
log.warning("Skipping unconnected connection: %s", conn)
|
||||
continue
|
||||
response = conn.recv()
|
||||
if response is not None:
|
||||
decoded = decoder_fn(response)
|
||||
return decoded
|
||||
request = encoder_fn(payloads=payloads)
|
||||
future = conn.send(request)
|
||||
|
||||
raise KafkaUnavailableError('All servers failed to process request')
|
||||
# Block
|
||||
while not future.is_done:
|
||||
conn.recv()
|
||||
|
||||
if future.failed():
|
||||
log.error("Request failed: %s", future.exception)
|
||||
continue
|
||||
|
||||
return decoder_fn(future.value)
|
||||
|
||||
raise KafkaUnavailableError('All servers failed to process request: %s' % hosts)
|
||||
|
||||
def _payloads_by_broker(self, payloads):
|
||||
payloads_by_broker = collections.defaultdict(list)
|
||||
@@ -204,52 +214,56 @@ class KafkaClient(object):
|
||||
|
||||
# For each BrokerConnection keep the real socket so that we can use
|
||||
# a select to perform unblocking I/O
|
||||
connections_by_socket = {}
|
||||
connections_by_future = {}
|
||||
for broker, broker_payloads in six.iteritems(payloads_by_broker):
|
||||
if broker is None:
|
||||
failed_payloads(broker_payloads)
|
||||
continue
|
||||
|
||||
conn = self._get_conn(broker.host, broker.port)
|
||||
conn.connect()
|
||||
if not conn.connected():
|
||||
refresh_metadata = True
|
||||
failed_payloads(broker_payloads)
|
||||
continue
|
||||
|
||||
request = encoder_fn(payloads=broker_payloads)
|
||||
# decoder_fn=None signal that the server is expected to not
|
||||
# send a response. This probably only applies to
|
||||
# ProduceRequest w/ acks = 0
|
||||
expect_response = (decoder_fn is not None)
|
||||
correlation_id = conn.send(request, expect_response=expect_response)
|
||||
future = conn.send(request, expect_response=expect_response)
|
||||
|
||||
if correlation_id is None:
|
||||
if future.failed():
|
||||
refresh_metadata = True
|
||||
failed_payloads(broker_payloads)
|
||||
log.warning('Error attempting to send request %s '
|
||||
'to server %s', correlation_id, broker)
|
||||
continue
|
||||
|
||||
if not expect_response:
|
||||
log.debug('Request %s does not expect a response '
|
||||
'(skipping conn.recv)', correlation_id)
|
||||
for payload in broker_payloads:
|
||||
topic_partition = (str(payload.topic), payload.partition)
|
||||
responses[topic_partition] = None
|
||||
continue
|
||||
|
||||
connections_by_socket[conn._read_fd] = (conn, broker)
|
||||
connections_by_future[future] = (conn, broker)
|
||||
|
||||
conn = None
|
||||
while connections_by_socket:
|
||||
sockets = connections_by_socket.keys()
|
||||
rlist, _, _ = select.select(sockets, [], [], None)
|
||||
conn, broker = connections_by_socket.pop(rlist[0])
|
||||
correlation_id = conn.next_correlation_id_recv()
|
||||
response = conn.recv()
|
||||
if response is None:
|
||||
refresh_metadata = True
|
||||
failed_payloads(payloads_by_broker[broker])
|
||||
log.warning('Error receiving response to request %s '
|
||||
'from server %s', correlation_id, broker)
|
||||
while connections_by_future:
|
||||
futures = list(connections_by_future.keys())
|
||||
for future in futures:
|
||||
|
||||
if not future.is_done:
|
||||
conn, _ = connections_by_future[future]
|
||||
conn.recv()
|
||||
continue
|
||||
|
||||
for payload_response in decoder_fn(response):
|
||||
_, broker = connections_by_future.pop(future)
|
||||
if future.failed():
|
||||
refresh_metadata = True
|
||||
failed_payloads(payloads_by_broker[broker])
|
||||
|
||||
else:
|
||||
for payload_response in decoder_fn(future.value):
|
||||
topic_partition = (str(payload_response.topic),
|
||||
payload_response.partition)
|
||||
responses[topic_partition] = payload_response
|
||||
@@ -392,7 +406,9 @@ class KafkaClient(object):
|
||||
|
||||
def reinit(self):
    """Close every cached broker connection and reconnect it.

    After closing, connect() is polled until the connection reports a
    state other than ConnectionStates.CONNECTING.
    """
    for broker_conn in self._conns.values():
        broker_conn.close()
        state = broker_conn.connect()
        # connect() is non-blocking; keep calling it while it is in progress
        while state == ConnectionStates.CONNECTING:
            state = broker_conn.connect()
|
||||
|
||||
def reset_topic_metadata(self, *topics):
|
||||
for topic in topics:
|
||||
|
@@ -73,7 +73,7 @@ class Cluster(object):
|
||||
|
||||
def _bootstrap(self, hosts, timeout=2):
|
||||
for host, port in hosts:
|
||||
conn = BrokerConnection(host, port, timeout)
|
||||
conn = BrokerConnection(host, port)
|
||||
if not conn.connect():
|
||||
continue
|
||||
self._brokers['bootstrap'] = conn
|
||||
|
@@ -93,6 +93,22 @@ class KafkaError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class IllegalStateError(KafkaError):
    """An operation was attempted on an object in the wrong state.

    Future raises this when success() or failure() is called on a future
    that has already been completed.
    """


class RetriableError(KafkaError):
    """Base class for errors that Future.retriable() reports as retriable."""


class DisconnectError(KafkaError):
    """The broker connection is not (or is no longer) connected.

    BrokerConnection uses it to fail a send() on an unconnected socket and
    as the default error for in-flight requests when close() is called.
    """


class CorrelationIdError(KafkaError):
    """A response carried a correlation id different from the one sent."""
|
||||
|
||||
|
||||
class BrokerResponseError(KafkaError):
|
||||
errno = None
|
||||
message = None
|
||||
|
299
kafka/conn.py
299
kafka/conn.py
@@ -1,15 +1,20 @@
|
||||
from collections import deque
|
||||
import collections
|
||||
import copy
|
||||
import errno
|
||||
import logging
|
||||
import io
|
||||
from random import shuffle
|
||||
from select import select
|
||||
import socket
|
||||
import struct
|
||||
from threading import local
|
||||
import time
|
||||
|
||||
import six
|
||||
|
||||
import kafka.common as Errors
|
||||
from kafka.common import ConnectionError
|
||||
from kafka.future import Future
|
||||
from kafka.protocol.api import RequestHeader
|
||||
from kafka.protocol.types import Int32
|
||||
|
||||
@@ -20,106 +25,244 @@ DEFAULT_SOCKET_TIMEOUT_SECONDS = 120
|
||||
DEFAULT_KAFKA_PORT = 9092
|
||||
|
||||
|
||||
class BrokerConnection(local):
|
||||
def __init__(self, host, port, timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS,
|
||||
client_id='kafka-python-0.10.0', correlation_id=0):
|
||||
super(BrokerConnection, self).__init__()
|
||||
class ConnectionStates(object):
    """Integer constants for the states a BrokerConnection can be in."""
    # DISCONNECTED is the initial state and the state after close();
    # CONNECTING means a non-blocking connect is still in progress.
    DISCONNECTED, CONNECTING, CONNECTED = 1, 2, 3
|
||||
|
||||
|
||||
# Record kept for every request that was sent and still awaits a response.
InFlightRequest = collections.namedtuple(
    'InFlightRequest',
    'request response_type correlation_id future timestamp')
|
||||
|
||||
|
||||
class BrokerConnection(object):
|
||||
_receive_buffer_bytes = 32768
|
||||
_send_buffer_bytes = 32768
|
||||
_client_id = 'kafka-python-0.10.0'
|
||||
_correlation_id = 0
|
||||
_request_timeout_ms = 40000
|
||||
|
||||
def __init__(self, host, port, **kwargs):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout = timeout
|
||||
self._write_fd = None
|
||||
self._read_fd = None
|
||||
self.correlation_id = correlation_id
|
||||
self.client_id = client_id
|
||||
self.in_flight_requests = deque()
|
||||
self.in_flight_requests = collections.deque()
|
||||
|
||||
for config in ('receive_buffer_bytes', 'send_buffer_bytes',
|
||||
'client_id', 'correlation_id', 'request_timeout_ms'):
|
||||
if config in kwargs:
|
||||
setattr(self, '_' + config, kwargs.pop(config))
|
||||
|
||||
self.state = ConnectionStates.DISCONNECTED
|
||||
self._sock = None
|
||||
self._rbuffer = io.BytesIO()
|
||||
self._receiving = False
|
||||
self._next_payload_bytes = 0
|
||||
self._last_connection_attempt = None
|
||||
self._last_connection_failure = None
|
||||
|
||||
def connect(self):
    """Attempt to connect and return the resulting ConnectionStates value.

    The call never blocks. When DISCONNECTED, a new non-blocking socket is
    created and a connect is started. When CONNECTING, connect_ex() is
    called again to poll the progress of the pending connect; if more than
    request_timeout_ms has elapsed since the attempt began, the attempt is
    abandoned and the connection is closed.

    Returns:
        The connection state after this attempt (also stored in self.state).
    """
    if self.state is ConnectionStates.DISCONNECTED:
        self.close()
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self._receive_buffer_bytes)
        self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self._send_buffer_bytes)
        self._sock.setblocking(False)
        ret = self._sock.connect_ex((self.host, self.port))
        self._last_connection_attempt = time.time()

        # errno values are plain ints: compare by equality, not identity
        if not ret or ret == errno.EISCONN:
            self.state = ConnectionStates.CONNECTED
        elif ret in (errno.EINPROGRESS, errno.EALREADY):
            self.state = ConnectionStates.CONNECTING
        else:
            log.error('Connect attempt returned error %s. Disconnecting.', ret)
            self.close()
            self._last_connection_failure = time.time()

    if self.state is ConnectionStates.CONNECTING:
        # in non-blocking mode, use repeated calls to socket.connect_ex
        # to check connection status
        if time.time() > (self._request_timeout_ms / 1000.0) + self._last_connection_attempt:
            log.error('Connection attempt timed out')
            self.close()  # error=TimeoutError ?
            self._last_connection_failure = time.time()
            # close() dropped the socket; polling it again would fail
            return self.state

        ret = self._sock.connect_ex((self.host, self.port))
        if not ret or ret == errno.EISCONN:
            self.state = ConnectionStates.CONNECTED
        elif ret != errno.EALREADY:
            log.error('Connect attempt returned error %s. Disconnecting.', ret)
            self.close()
            self._last_connection_failure = time.time()
    return self.state
|
||||
|
||||
def connected(self):
|
||||
return (self._read_fd is not None and self._write_fd is not None)
|
||||
return self.state is ConnectionStates.CONNECTED
|
||||
|
||||
def close(self):
|
||||
if self.connected():
|
||||
try:
|
||||
self._read_fd.close()
|
||||
self._write_fd.close()
|
||||
except socket.error:
|
||||
log.exception("Error in BrokerConnection.close()")
|
||||
pass
|
||||
self._read_fd = None
|
||||
self._write_fd = None
|
||||
def close(self, error=None):
    """Close the socket, fail all in-flight requests and reset receive state.

    Arguments:
        error: exception used to fail the futures of pending requests;
            a new DisconnectError is used when omitted.
    """
    sock = self._sock
    if sock:
        sock.close()
        self._sock = None
    self.state = ConnectionStates.DISCONNECTED

    reason = Errors.DisconnectError() if error is None else error
    pending = self.in_flight_requests
    while pending:
        pending.popleft().future.failure(reason)
    pending.clear()

    # forget any partially received response
    self._receiving = False
    self._next_payload_bytes = 0
    self._rbuffer.seek(0)
    self._rbuffer.truncate()
|
||||
|
||||
def send(self, request, expect_response=True):
    """send request, return Future()

    Can block on network if request is larger than send_buffer_bytes

    Arguments:
        request: the request object; it is encoded behind a RequestHeader
            and a 4-byte size prefix.
        expect_response: when False the returned future is completed
            immediately with None and nothing is queued for recv().

    Returns:
        A Future. It is already failed with DisconnectError if the
        connection is not connected, or with the socket error if writing
        fails (the connection is closed in that case). Otherwise it is
        completed when the matching response is processed.
    """
    future = Future()
    if not self.connected():
        return future.failure(Errors.DisconnectError())
    self._correlation_id += 1
    header = RequestHeader(request,
                           correlation_id=self._correlation_id,
                           client_id=self._client_id)
    message = b''.join([header.encode(), request.encode()])
    size = Int32.encode(len(message))
    try:
        # In the future we might manage an internal write buffer
        # and send bytes asynchronously. For now, just block
        # sending each request payload
        self._sock.setblocking(True)
        # sendall() keeps writing until every byte is sent; a single
        # send() may legitimately write only part of the buffer
        self._sock.sendall(size + message)
        self._sock.setblocking(False)
    except socket.error as e:
        log.debug("Error in BrokerConnection.send(): %s", request)
        self.close(error=e)
        return future.failure(e)
    log.debug('Request %d: %s', self._correlation_id, request)

    if expect_response:
        ifr = InFlightRequest(request=request,
                              correlation_id=self._correlation_id,
                              response_type=request.RESPONSE_TYPE,
                              future=future,
                              timestamp=time.time())
        self.in_flight_requests.append(ifr)
    else:
        future.success(None)

    return future
|
||||
|
||||
def recv(self, timeout=0):
    """Non-blocking network receive

    Return response if available

    Each call reads at most one chunk from the socket. A response is
    framed as a 4-byte size header followed by that many payload bytes;
    partial reads are staged in self._rbuffer across calls. Requests that
    exceeded the request timeout are failed before reading.

    Arguments:
        timeout: seconds passed to select() while waiting for the socket
            to become readable (0 polls without waiting).

    Returns:
        The decoded response once a complete payload has been received,
        otherwise None.
    """
    if not self.connected():
        log.warning('Cannot recv: socket not connected')
        # If requests are pending, we should close the socket and
        # fail all the pending request futures
        if self.in_flight_requests:
            self.close()
        return None

    if not self.in_flight_requests:
        log.warning('No in-flight-requests to recv')
        return None

    self._fail_timed_out_requests()

    readable, _, _ = select([self._sock], [], [], timeout)
    if not readable:
        return None

    # Not receiving is the state of reading the payload header
    if not self._receiving:
        try:
            # An extremely small, but non-zero, probability that there are
            # more than 0 but not yet 4 bytes available to read
            header_bytes = self._sock.recv(4 - self._rbuffer.tell())
        except socket.error as e:
            if e.errno == errno.EWOULDBLOCK:
                # This shouldn't happen after selecting above
                # but just in case
                return None
            log.exception("Error receiving 4-byte payload header - closing socket")
            self.close(error=e)
            return None

        if not header_bytes:
            # A readable socket that yields no data was closed by the peer;
            # close so pending futures fail instead of waiting for timeout
            log.error('Socket disconnected while receiving payload header - closing socket')
            self.close()
            return None
        self._rbuffer.write(header_bytes)

        if self._rbuffer.tell() == 4:
            self._rbuffer.seek(0)
            self._next_payload_bytes = Int32.decode(self._rbuffer)
            # reset buffer and switch state to receiving payload bytes
            self._rbuffer.seek(0)
            self._rbuffer.truncate()
            self._receiving = True
        elif self._rbuffer.tell() > 4:
            raise Errors.KafkaError('this should not happen - are you threading?')

    if self._receiving:
        staged_bytes = self._rbuffer.tell()
        try:
            payload_bytes = self._sock.recv(self._next_payload_bytes - staged_bytes)
        except socket.error as e:
            # Extremely small chance that we have exactly 4 bytes for a
            # header, but nothing to read in the body yet
            if e.errno == errno.EWOULDBLOCK:
                return None
            log.exception("Error receiving payload bytes - closing socket")
            self.close(error=e)
            return None

        if not payload_bytes:
            # peer closed the connection in the middle of a response
            log.error('Socket disconnected while receiving payload - closing socket')
            self.close()
            return None
        self._rbuffer.write(payload_bytes)

        staged_bytes = self._rbuffer.tell()
        if staged_bytes > self._next_payload_bytes:
            self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
            return None

        if staged_bytes != self._next_payload_bytes:
            # payload still incomplete; continue on the next call
            return None

        self._receiving = False
        self._next_payload_bytes = 0
        self._rbuffer.seek(0)
        response = self._process_response(self._rbuffer)
        self._rbuffer.seek(0)
        self._rbuffer.truncate()
        return response

    # header not complete yet
    return None
|
||||
|
||||
def next_correlation_id_recv(self):
|
||||
if len(self.in_flight_requests) == 0:
|
||||
def _process_response(self, read_buffer):
    """Decode one complete response payload and complete its future.

    The oldest in-flight request is taken as the one being answered.

    Arguments:
        read_buffer: file-like buffer positioned at the start of the
            payload (the correlation id comes first).

    Returns:
        The decoded response, or None if the correlation id did not match
        (the request's future is failed with CorrelationIdError and the
        connection is closed).
    """
    ifr = self.in_flight_requests.popleft()

    # verify send/recv correlation ids match
    recv_correlation_id = Int32.decode(read_buffer)
    if ifr.correlation_id != recv_correlation_id:
        error = Errors.CorrelationIdError(
            'Correlation ids do not match: sent %d, recv %d'
            % (ifr.correlation_id, recv_correlation_id))
        # Future exposes failure(), not fail()
        ifr.future.failure(error)
        self.close()
        return None

    # decode response
    response = ifr.response_type.decode(read_buffer)
    ifr.future.success(response)
    log.debug('Response %d: %s', ifr.correlation_id, response)
    return response
|
||||
|
||||
def __getnewargs__(self):
|
||||
return (self.host, self.port, self.timeout)
|
||||
def _fail_timed_out_requests(self):
|
||||
now = time.time()
|
||||
while self.in_flight_requests:
|
||||
next_timeout = self.in_flight_requests[0].timestamp + (self._request_timeout_ms / 1000.0)
|
||||
if now < next_timeout:
|
||||
break
|
||||
timed_out = self.in_flight_requests.popleft()
|
||||
error = Errors.RequestTimedOutError('Request timed out after %s ms' % self._request_timeout_ms)
|
||||
timed_out.future.failure(error)
|
||||
|
||||
def __repr__(self):
|
||||
return "<BrokerConnection host=%s port=%d>" % (self.host, self.port)
|
||||
@@ -149,13 +292,7 @@ def collect_hosts(hosts, randomize=True):
|
||||
|
||||
|
||||
class KafkaConnection(local):
|
||||
"""
|
||||
A socket connection to a single Kafka broker
|
||||
|
||||
This class is _not_ thread safe. Each call to `send` must be followed
|
||||
by a call to `recv` in order to get the correct response. Eventually,
|
||||
we can do something in here to facilitate multiplexed requests/responses
|
||||
since the Kafka API includes a correlation id.
|
||||
"""A socket connection to a single Kafka broker
|
||||
|
||||
Arguments:
|
||||
host: the host name or IP address of a kafka broker
|
||||
|
51
kafka/future.py
Normal file
51
kafka/future.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from kafka.common import RetriableError, IllegalStateError
|
||||
|
||||
|
||||
class Future(object):
    """Placeholder for the result of an operation that completes later.

    A future is completed exactly once, either with success(value) or with
    failure(exception). Callbacks and errbacks registered before completion
    are invoked at completion time; those registered afterwards are invoked
    immediately if they match the outcome.
    """

    def __init__(self):
        self.is_done = False
        self.value = None
        self.exception = None
        self._callbacks = []
        self._errbacks = []

    def succeeded(self):
        """Return True if the future completed without an exception."""
        return self.is_done and self.exception is None

    def failed(self):
        """Return True if the future completed with an exception."""
        # compare against None so the result is always a bool and an
        # exception object that happens to be falsy still counts as failure
        return self.is_done and self.exception is not None

    def retriable(self):
        """Return True if the stored exception is a RetriableError."""
        return isinstance(self.exception, RetriableError)

    def success(self, value):
        """Complete the future with value and run the registered callbacks.

        Returns the future itself. Raises IllegalStateError if the future
        was already completed.
        """
        if self.is_done:
            raise IllegalStateError('Invalid attempt to complete a request future which is already complete')
        self.value = value
        self.is_done = True
        for f in self._callbacks:
            f(value)
        return self

    def failure(self, e):
        """Complete the future with exception e and run the errbacks.

        Returns the future itself. Raises IllegalStateError if the future
        was already completed.
        """
        if self.is_done:
            raise IllegalStateError('Invalid attempt to complete a request future which is already complete')
        self.exception = e
        self.is_done = True
        for f in self._errbacks:
            f(e)
        return self

    def add_callback(self, f):
        """Register f(value); call it right away if already succeeded.

        Returns the future itself so calls can be chained.
        """
        if self.succeeded():
            f(self.value)
        else:
            self._callbacks.append(f)
        return self

    def add_errback(self, f):
        """Register f(exception); call it right away if already failed.

        Returns the future itself so calls can be chained.
        """
        if self.failed():
            f(self.exception)
        else:
            self._errbacks.append(f)
        return self
|
@@ -14,6 +14,7 @@ from kafka.common import (
|
||||
KafkaTimeoutError, ConnectionError
|
||||
)
|
||||
from kafka.conn import KafkaConnection
|
||||
from kafka.future import Future
|
||||
from kafka.protocol import KafkaProtocol, create_message
|
||||
from kafka.protocol.metadata import MetadataResponse
|
||||
|
||||
@@ -23,6 +24,17 @@ NO_ERROR = 0
|
||||
UNKNOWN_TOPIC_OR_PARTITION = 3
|
||||
NO_LEADER = 5
|
||||
|
||||
|
||||
def mock_conn(conn, success=True):
    """Make the patched conn return a connected mock connection.

    The mock's send() returns an already-completed Future: succeeded with
    True when success is set, otherwise failed with a plain Exception.
    """
    stub = MagicMock()
    stub.connected.return_value = True
    stub.send.return_value = (
        Future().success(True) if success else Future().failure(Exception()))
    conn.return_value = stub
|
||||
|
||||
|
||||
class TestKafkaClient(unittest.TestCase):
|
||||
def test_init_with_list(self):
|
||||
with patch.object(KafkaClient, 'load_metadata_for_topics'):
|
||||
@@ -48,22 +60,20 @@ class TestKafkaClient(unittest.TestCase):
|
||||
sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]),
|
||||
sorted(client.hosts))
|
||||
|
||||
def test_send_broker_unaware_request_fail(self):
|
||||
@patch.object(KafkaClient, '_get_conn')
|
||||
@patch.object(KafkaClient, 'load_metadata_for_topics')
|
||||
def test_send_broker_unaware_request_fail(self, load_metadata, conn):
|
||||
mocked_conns = {
|
||||
('kafka01', 9092): MagicMock(),
|
||||
('kafka02', 9092): MagicMock()
|
||||
}
|
||||
|
||||
# inject KafkaConnection side effects
|
||||
mocked_conns[('kafka01', 9092)].send.return_value = None
|
||||
mocked_conns[('kafka02', 9092)].send.return_value = None
|
||||
for val in mocked_conns.values():
|
||||
mock_conn(val, success=False)
|
||||
|
||||
def mock_get_conn(host, port):
|
||||
return mocked_conns[(host, port)]
|
||||
conn.side_effect = mock_get_conn
|
||||
|
||||
# patch to avoid making requests before we want it
|
||||
with patch.object(KafkaClient, 'load_metadata_for_topics'):
|
||||
with patch.object(KafkaClient, '_get_conn', side_effect=mock_get_conn):
|
||||
client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092'])
|
||||
|
||||
req = KafkaProtocol.encode_metadata_request()
|
||||
@@ -82,9 +92,11 @@ class TestKafkaClient(unittest.TestCase):
|
||||
('kafka03', 9092): MagicMock()
|
||||
}
|
||||
# inject KafkaConnection side effects
|
||||
mocked_conns[('kafka01', 9092)].send.return_value = None
|
||||
mocked_conns[('kafka02', 9092)].recv.return_value = 'valid response'
|
||||
mocked_conns[('kafka03', 9092)].send.return_value = None
|
||||
mock_conn(mocked_conns[('kafka01', 9092)], success=False)
|
||||
mock_conn(mocked_conns[('kafka03', 9092)], success=False)
|
||||
future = Future()
|
||||
mocked_conns[('kafka02', 9092)].send.return_value = future
|
||||
mocked_conns[('kafka02', 9092)].recv.side_effect = lambda: future.success('valid response')
|
||||
|
||||
def mock_get_conn(host, port):
|
||||
return mocked_conns[(host, port)]
|
||||
@@ -101,11 +113,11 @@ class TestKafkaClient(unittest.TestCase):
|
||||
self.assertEqual('valid response', resp)
|
||||
mocked_conns[('kafka02', 9092)].recv.assert_called_once_with()
|
||||
|
||||
@patch('kafka.client.BrokerConnection')
|
||||
@patch('kafka.client.KafkaClient._get_conn')
|
||||
@patch('kafka.client.KafkaProtocol')
|
||||
def test_load_metadata(self, protocol, conn):
|
||||
|
||||
conn.recv.return_value = 'response' # anything but None
|
||||
mock_conn(conn)
|
||||
|
||||
brokers = [
|
||||
BrokerMetadata(0, 'broker_1', 4567),
|
||||
@@ -151,11 +163,11 @@ class TestKafkaClient(unittest.TestCase):
|
||||
# This should not raise
|
||||
client.load_metadata_for_topics('topic_no_leader')
|
||||
|
||||
@patch('kafka.client.BrokerConnection')
|
||||
@patch('kafka.client.KafkaClient._get_conn')
|
||||
@patch('kafka.client.KafkaProtocol')
|
||||
def test_has_metadata_for_topic(self, protocol, conn):
|
||||
|
||||
conn.recv.return_value = 'response' # anything but None
|
||||
mock_conn(conn)
|
||||
|
||||
brokers = [
|
||||
BrokerMetadata(0, 'broker_1', 4567),
|
||||
@@ -181,11 +193,11 @@ class TestKafkaClient(unittest.TestCase):
|
||||
# Topic with partition metadata, but no leaders return True
|
||||
self.assertTrue(client.has_metadata_for_topic('topic_noleaders'))
|
||||
|
||||
@patch('kafka.client.BrokerConnection')
|
||||
@patch('kafka.client.KafkaClient._get_conn')
|
||||
@patch('kafka.client.KafkaProtocol.decode_metadata_response')
|
||||
def test_ensure_topic_exists(self, decode_metadata_response, conn):
|
||||
|
||||
conn.recv.return_value = 'response' # anything but None
|
||||
mock_conn(conn)
|
||||
|
||||
brokers = [
|
||||
BrokerMetadata(0, 'broker_1', 4567),
|
||||
@@ -213,12 +225,12 @@ class TestKafkaClient(unittest.TestCase):
|
||||
# This should not raise
|
||||
client.ensure_topic_exists('topic_noleaders', timeout=1)
|
||||
|
||||
@patch('kafka.client.BrokerConnection')
|
||||
@patch('kafka.client.KafkaClient._get_conn')
|
||||
@patch('kafka.client.KafkaProtocol')
|
||||
def test_get_leader_for_partitions_reloads_metadata(self, protocol, conn):
|
||||
"Get leader for partitions reload metadata if it is not available"
|
||||
|
||||
conn.recv.return_value = 'response' # anything but None
|
||||
mock_conn(conn)
|
||||
|
||||
brokers = [
|
||||
BrokerMetadata(0, 'broker_1', 4567),
|
||||
@@ -251,11 +263,11 @@ class TestKafkaClient(unittest.TestCase):
|
||||
TopicAndPartition('topic_one_partition', 0): brokers[0]},
|
||||
client.topics_to_brokers)
|
||||
|
||||
@patch('kafka.client.BrokerConnection')
|
||||
@patch('kafka.client.KafkaClient._get_conn')
|
||||
@patch('kafka.client.KafkaProtocol')
|
||||
def test_get_leader_for_unassigned_partitions(self, protocol, conn):
|
||||
|
||||
conn.recv.return_value = 'response' # anything but None
|
||||
mock_conn(conn)
|
||||
|
||||
brokers = [
|
||||
BrokerMetadata(0, 'broker_1', 4567),
|
||||
@@ -278,11 +290,11 @@ class TestKafkaClient(unittest.TestCase):
|
||||
with self.assertRaises(UnknownTopicOrPartitionError):
|
||||
client._get_leader_for_partition('topic_unknown', 0)
|
||||
|
||||
@patch('kafka.client.BrokerConnection')
|
||||
@patch('kafka.client.KafkaClient._get_conn')
|
||||
@patch('kafka.client.KafkaProtocol')
|
||||
def test_get_leader_exceptions_when_noleader(self, protocol, conn):
|
||||
|
||||
conn.recv.return_value = 'response' # anything but None
|
||||
mock_conn(conn)
|
||||
|
||||
brokers = [
|
||||
BrokerMetadata(0, 'broker_1', 4567),
|
||||
@@ -325,10 +337,10 @@ class TestKafkaClient(unittest.TestCase):
|
||||
self.assertEqual(brokers[0], client._get_leader_for_partition('topic_noleader', 0))
|
||||
self.assertEqual(brokers[1], client._get_leader_for_partition('topic_noleader', 1))
|
||||
|
||||
@patch('kafka.client.BrokerConnection')
|
||||
@patch.object(KafkaClient, '_get_conn')
|
||||
@patch('kafka.client.KafkaProtocol')
|
||||
def test_send_produce_request_raises_when_noleader(self, protocol, conn):
|
||||
conn.recv.return_value = 'response' # anything but None
|
||||
mock_conn(conn)
|
||||
|
||||
brokers = [
|
||||
BrokerMetadata(0, 'broker_1', 4567),
|
||||
@@ -352,11 +364,11 @@ class TestKafkaClient(unittest.TestCase):
|
||||
with self.assertRaises(LeaderNotAvailableError):
|
||||
client.send_produce_request(requests)
|
||||
|
||||
@patch('kafka.client.BrokerConnection')
|
||||
@patch('kafka.client.KafkaClient._get_conn')
|
||||
@patch('kafka.client.KafkaProtocol')
|
||||
def test_send_produce_request_raises_when_topic_unknown(self, protocol, conn):
|
||||
|
||||
conn.recv.return_value = 'response' # anything but None
|
||||
mock_conn(conn)
|
||||
|
||||
brokers = [
|
||||
BrokerMetadata(0, 'broker_1', 4567),
|
||||
|
Reference in New Issue
Block a user