Merge pull request #515 from dpkp/kafka_producer

KafkaProducer
This commit is contained in:
Dana Powers
2016-01-24 18:36:46 -08:00
17 changed files with 2152 additions and 302 deletions

View File

@@ -50,7 +50,34 @@ for examples.
KafkaProducer
*************
<`in progress - see SimpleProducer for legacy producer implementation`>
KafkaProducer is a high-level, asynchronous message producer. The class is
intended to operate as similarly as possible to the official java client.
See `ReadTheDocs <http://kafka-python.readthedocs.org/en/master/apidoc/KafkaProducer.html>`_
for more details.
>>> from kafka import KafkaProducer
>>> producer = KafkaProducer(bootstrap_servers='localhost:1234')
>>> producer.send('foobar', b'some_message_bytes')
>>> # Blocking send
>>> producer.send('foobar', b'another_message').get(timeout=60)
>>> # Use a key for hashed-partitioning
>>> producer.send('foobar', key=b'foo', value=b'bar')
>>> # Serialize json messages
>>> import json
>>> producer = KafkaProducer(value_serializer=json.loads)
>>> producer.send('fizzbuzz', {'foo': 'bar'})
>>> # Serialize string keys
>>> producer = KafkaProducer(key_serializer=str.encode)
>>> producer.send('flipflap', key='ping', value=b'1234')
>>> # Compress messages
>>> producer = KafkaProducer(compression_type='gzip')
>>> for i in range(1000):
... producer.send('foobar', b'msg %d' % i)
Protocol

View File

@@ -1,4 +1,5 @@
KafkaProducer
=============
<unreleased> See :class:`kafka.producer.SimpleProducer`
.. autoclass:: kafka.KafkaProducer
:members:

View File

@@ -5,6 +5,7 @@ __license__ = 'Apache License 2.0'
__copyright__ = 'Copyright 2016 Dana Powers, David Arthur, and Contributors'
from kafka.consumer import KafkaConsumer
from kafka.producer import KafkaProducer
from kafka.conn import BrokerConnection
from kafka.protocol import (
create_message, create_gzip_message, create_snappy_message)
@@ -28,7 +29,7 @@ class KafkaClient(SimpleClient):
__all__ = [
'KafkaConsumer', 'KafkaClient', 'BrokerConnection',
'KafkaConsumer', 'KafkaProducer', 'KafkaClient', 'BrokerConnection',
'SimpleClient', 'SimpleProducer', 'KeyedProducer',
'RoundRobinPartitioner', 'HashedPartitioner',
'create_message', 'create_gzip_message', 'create_snappy_message',

View File

@@ -0,0 +1,23 @@
import random
from .hashed import murmur2
class DefaultPartitioner(object):
"""Default partitioner.
Hashes key to partition using murmur2 hashing (from java client)
If key is None, selects partition randomly from available,
or from all partitions if none are currently available
"""
@classmethod
def __call__(cls, key, all_partitions, available):
if key is None:
if available:
return random.choice(available)
return random.choice(all_partitions)
idx = murmur2(key)
idx &= 0x7fffffff
idx %= len(all_partitions)
return all_partitions[idx]

View File

@@ -1,6 +1,8 @@
from .kafka import KafkaProducer
from .simple import SimpleProducer
from .keyed import KeyedProducer
__all__ = [
'SimpleProducer', 'KeyedProducer'
'KafkaProducer',
'SimpleProducer', 'KeyedProducer' # deprecated
]

388
kafka/producer/buffer.py Normal file
View File

@@ -0,0 +1,388 @@
from __future__ import absolute_import
import collections
import io
import threading
import time
from ..codec import (has_gzip, has_snappy,
gzip_encode, snappy_encode)
from ..protocol.types import Int32, Int64
from ..protocol.message import MessageSet, Message
import kafka.common as Errors
class MessageSetBuffer(object):
"""Wrap a buffer for writing MessageSet batches.
Arguments:
buf (IO stream): a buffer for writing data. Typically BytesIO.
batch_size (int): maximum number of bytes to write to the buffer.
Keyword Arguments:
compression_type ('gzip', 'snappy', None): compress messages before
publishing. Default: None.
"""
_COMPRESSORS = {
'gzip': (has_gzip, gzip_encode, Message.CODEC_GZIP),
'snappy': (has_snappy, snappy_encode, Message.CODEC_SNAPPY),
}
def __init__(self, buf, batch_size, compression_type=None):
assert batch_size > 0, 'batch_size must be > 0'
if compression_type is not None:
assert compression_type in self._COMPRESSORS, 'Unrecognized compression type'
checker, encoder, attributes = self._COMPRESSORS[compression_type]
assert checker(), 'Compression Libraries Not Found'
self._compressor = encoder
self._compression_attributes = attributes
else:
self._compressor = None
self._compression_attributes = None
self._buffer = buf
# Init MessageSetSize to 0 -- update on close
self._buffer.seek(0)
self._buffer.write(Int32.encode(0))
self._batch_size = batch_size
self._closed = False
self._messages = 0
def append(self, offset, message):
"""Apend a Message to the MessageSet.
Arguments:
offset (int): offset of the message
message (Message or bytes): message struct or encoded bytes
"""
if isinstance(message, Message):
encoded = message.encode()
else:
encoded = bytes(message)
msg = Int64.encode(offset) + Int32.encode(len(encoded)) + encoded
self._buffer.write(msg)
self._messages += 1
def has_room_for(self, key, value):
if self._closed:
return False
if not self._messages:
return True
needed_bytes = MessageSet.HEADER_SIZE + Message.HEADER_SIZE
if key is not None:
needed_bytes += len(key)
if value is not None:
needed_bytes += len(value)
return self._buffer.tell() + needed_bytes < self._batch_size
def is_full(self):
if self._closed:
return True
return self._buffer.tell() >= self._batch_size
def close(self):
if self._compressor:
# TODO: avoid copies with bytearray / memoryview
self._buffer.seek(4)
msg = Message(self._compressor(self._buffer.read()),
attributes=self._compression_attributes)
encoded = msg.encode()
self._buffer.seek(4)
self._buffer.write(Int64.encode(0)) # offset 0 for wrapper msg
self._buffer.write(Int32.encode(len(encoded)))
self._buffer.write(encoded)
# Update the message set size, and return ready for full read()
size = self._buffer.tell() - 4
self._buffer.seek(0)
self._buffer.write(Int32.encode(size))
self._buffer.seek(0)
self._closed = True
def size_in_bytes(self):
return self._buffer.tell()
def buffer(self):
return self._buffer
class SimpleBufferPool(object):
"""A simple pool of BytesIO objects with a weak memory ceiling."""
def __init__(self, memory, poolable_size):
"""Create a new buffer pool.
Arguments:
memory (int): maximum memory that this buffer pool can allocate
poolable_size (int): memory size per buffer to cache in the free
list rather than deallocating
"""
self._poolable_size = poolable_size
self._lock = threading.RLock()
buffers = int(memory / poolable_size)
self._free = collections.deque([io.BytesIO() for _ in range(buffers)])
self._waiters = collections.deque()
#self.metrics = metrics;
#self.waitTime = this.metrics.sensor("bufferpool-wait-time");
#MetricName metricName = metrics.metricName("bufferpool-wait-ratio", metricGrpName, "The fraction of time an appender waits for space allocation.");
#this.waitTime.add(metricName, new Rate(TimeUnit.NANOSECONDS));
def allocate(self, max_time_to_block_ms):
"""
Allocate a buffer of the given size. This method blocks if there is not
enough memory and the buffer pool is configured with blocking mode.
Arguments:
max_time_to_block_ms (int): The maximum time in milliseconds to
block for buffer memory to be available
Returns:
io.BytesIO
"""
with self._lock:
# check if we have a free buffer of the right size pooled
if self._free:
return self._free.popleft()
else:
# we are out of buffers and will have to block
buf = None
more_memory = threading.Condition(self._lock)
self._waiters.append(more_memory)
# loop over and over until we have a buffer or have reserved
# enough memory to allocate one
while buf is None:
start_wait = time.time()
if not more_memory.wait(max_time_to_block_ms / 1000.0):
raise Errors.KafkaTimeoutError(
"Failed to allocate memory within the configured"
" max blocking time")
end_wait = time.time()
#this.waitTime.record(endWait - startWait, time.milliseconds());
if self._free:
buf = self._free.popleft()
# remove the condition for this thread to let the next thread
# in line start getting memory
removed = self._waiters.popleft()
assert removed is more_memory, 'Wrong condition'
# signal any additional waiters if there is more memory left
# over for them
if self._free and self._waiters:
self._waiters[0].notify()
# unlock and return the buffer
return buf
def deallocate(self, buf):
"""
Return buffers to the pool. If they are of the poolable size add them
to the free list, otherwise just mark the memory as free.
Arguments:
buffer_ (io.BytesIO): The buffer to return
"""
with self._lock:
capacity = buf.seek(0, 2)
# free extra memory if needed
if capacity > self._poolable_size:
# BytesIO (cpython) only frees memory if 2x reduction or more
trunc_to = int(min(capacity / 2, self._poolable_size))
buf.truncate(trunc_to)
buf.seek(0)
#buf.write(bytearray(12))
#buf.seek(0)
self._free.append(buf)
if self._waiters:
self._waiters[0].notify()
def queued(self):
"""The number of threads blocked waiting on memory."""
with self._lock:
return len(self._waiters)
'''
class BufferPool(object):
"""
A pool of ByteBuffers kept under a given memory limit. This class is fairly
specific to the needs of the producer. In particular it has the following
properties:
* There is a special "poolable size" and buffers of this size are kept in a
free list and recycled
* It is fair. That is all memory is given to the longest waiting thread
until it has sufficient memory. This prevents starvation or deadlock when
a thread asks for a large chunk of memory and needs to block until
multiple buffers are deallocated.
"""
def __init__(self, memory, poolable_size):
"""Create a new buffer pool.
Arguments:
memory (int): maximum memory that this buffer pool can allocate
poolable_size (int): memory size per buffer to cache in the free
list rather than deallocating
"""
self._poolable_size = poolable_size
self._lock = threading.RLock()
self._free = collections.deque()
self._waiters = collections.deque()
self._total_memory = memory
self._available_memory = memory
#self.metrics = metrics;
#self.waitTime = this.metrics.sensor("bufferpool-wait-time");
#MetricName metricName = metrics.metricName("bufferpool-wait-ratio", metricGrpName, "The fraction of time an appender waits for space allocation.");
#this.waitTime.add(metricName, new Rate(TimeUnit.NANOSECONDS));
def allocate(self, size, max_time_to_block_ms):
"""
Allocate a buffer of the given size. This method blocks if there is not
enough memory and the buffer pool is configured with blocking mode.
Arguments:
size (int): The buffer size to allocate in bytes
max_time_to_block_ms (int): The maximum time in milliseconds to
block for buffer memory to be available
Returns:
buffer
Raises:
InterruptedException If the thread is interrupted while blocked
IllegalArgumentException if size is larger than the total memory
controlled by the pool (and hence we would block forever)
"""
assert size <= self._total_memory, (
"Attempt to allocate %d bytes, but there is a hard limit of %d on"
" memory allocations." % (size, self._total_memory))
with self._lock:
# check if we have a free buffer of the right size pooled
if (size == self._poolable_size and len(self._free) > 0):
return self._free.popleft()
# now check if the request is immediately satisfiable with the
# memory on hand or if we need to block
free_list_size = len(self._free) * self._poolable_size
if self._available_memory + free_list_size >= size:
# we have enough unallocated or pooled memory to immediately
# satisfy the request
self._free_up(size)
self._available_memory -= size
raise NotImplementedError()
#return ByteBuffer.allocate(size)
else:
# we are out of memory and will have to block
accumulated = 0
buf = None
more_memory = threading.Condition(self._lock)
self._waiters.append(more_memory)
# loop over and over until we have a buffer or have reserved
# enough memory to allocate one
while (accumulated < size):
start_wait = time.time()
if not more_memory.wait(max_time_to_block_ms / 1000.0):
raise Errors.KafkaTimeoutError(
"Failed to allocate memory within the configured"
" max blocking time")
end_wait = time.time()
#this.waitTime.record(endWait - startWait, time.milliseconds());
# check if we can satisfy this request from the free list,
# otherwise allocate memory
if (accumulated == 0
and size == self._poolable_size
and self._free):
# just grab a buffer from the free list
buf = self._free.popleft()
accumulated = size
else:
# we'll need to allocate memory, but we may only get
# part of what we need on this iteration
self._free_up(size - accumulated)
got = min(size - accumulated, self._available_memory)
self._available_memory -= got
accumulated += got
# remove the condition for this thread to let the next thread
# in line start getting memory
removed = self._waiters.popleft()
assert removed is more_memory, 'Wrong condition'
# signal any additional waiters if there is more memory left
# over for them
if (self._available_memory > 0 or len(self._free) > 0):
if len(self._waiters) > 0:
self._waiters[0].notify()
# unlock and return the buffer
if buf is None:
raise NotImplementedError()
#return ByteBuffer.allocate(size)
else:
return buf
def _free_up(self, size):
"""
Attempt to ensure we have at least the requested number of bytes of
memory for allocation by deallocating pooled buffers (if needed)
"""
while self._free and self._available_memory < size:
self._available_memory += self._free.pop().capacity
def deallocate(self, buffer_, size=None):
"""
Return buffers to the pool. If they are of the poolable size add them
to the free list, otherwise just mark the memory as free.
Arguments:
buffer (io.BytesIO): The buffer to return
size (int): The size of the buffer to mark as deallocated, note
that this maybe smaller than buffer.capacity since the buffer
may re-allocate itself during in-place compression
"""
with self._lock:
if size is None:
size = buffer_.capacity
if (size == self._poolable_size and size == buffer_.capacity):
buffer_.seek(0)
buffer_.truncate()
self._free.append(buffer_)
else:
self._available_memory += size
if self._waiters:
more_mem = self._waiters[0]
more_mem.notify()
def available_memory(self):
"""The total free memory both unallocated and in the free list."""
with self._lock:
return self._available_memory + len(self._free) * self._poolable_size
def unallocated_memory(self):
"""Get the unallocated memory (not in the free list or in use)."""
with self._lock:
return self._available_memory
def queued(self):
"""The number of threads blocked waiting on memory."""
with self._lock:
return len(self._waiters)
def poolable_size(self):
"""The buffer size that will be retained in the free list after use."""
return self._poolable_size
def total_memory(self):
"""The total memory managed by this pool."""
return self._total_memory
'''

66
kafka/producer/future.py Normal file
View File

@@ -0,0 +1,66 @@
from __future__ import absolute_import
import collections
import threading
from ..future import Future
import kafka.common as Errors
class FutureProduceResult(Future):
def __init__(self, topic_partition):
super(FutureProduceResult, self).__init__()
self.topic_partition = topic_partition
self._latch = threading.Event()
def success(self, value):
ret = super(FutureProduceResult, self).success(value)
self._latch.set()
return ret
def failure(self, error):
ret = super(FutureProduceResult, self).failure(error)
self._latch.set()
return ret
def await(self, timeout=None):
return self._latch.wait(timeout)
class FutureRecordMetadata(Future):
def __init__(self, produce_future, relative_offset):
super(FutureRecordMetadata, self).__init__()
self._produce_future = produce_future
self.relative_offset = relative_offset
produce_future.add_callback(self._produce_success)
produce_future.add_errback(self.failure)
def _produce_success(self, base_offset):
self.success(RecordMetadata(self._produce_future.topic_partition,
base_offset, self.relative_offset))
def get(self, timeout=None):
if not self.is_done and not self._produce_future.await(timeout):
raise Errors.KafkaTimeoutError(
"Timeout after waiting for %s secs." % timeout)
assert self.is_done
if self.failed():
raise self.exception # pylint: disable-msg=raising-bad-type
return self.value
class RecordMetadata(collections.namedtuple(
'RecordMetadata', 'topic partition topic_partition offset')):
def __new__(cls, tp, base_offset, relative_offset=None):
offset = base_offset
if relative_offset is not None and base_offset != -1:
offset += relative_offset
return super(RecordMetadata, cls).__new__(cls, tp.topic, tp.partition, tp, offset)
def __str__(self):
return 'RecordMetadata(topic=%s, partition=%s, offset=%s)' % (
self.topic, self.partition, self.offset)
def __repr__(self):
return str(self)

496
kafka/producer/kafka.py Normal file
View File

@@ -0,0 +1,496 @@
from __future__ import absolute_import
import atexit
import copy
import logging
import signal
import threading
import time
from ..client_async import KafkaClient
from ..common import TopicPartition
from ..partitioner.default import DefaultPartitioner
from ..protocol.message import Message, MessageSet
from .future import FutureRecordMetadata, FutureProduceResult
from .record_accumulator import AtomicInteger, RecordAccumulator
from .sender import Sender
import kafka.common as Errors
log = logging.getLogger(__name__)
PRODUCER_CLIENT_ID_SEQUENCE = AtomicInteger()
class KafkaProducer(object):
"""A Kafka client that publishes records to the Kafka cluster.
The producer is thread safe and sharing a single producer instance across
threads will generally be faster than having multiple instances.
The producer consists of a pool of buffer space that holds records that
haven't yet been transmitted to the server as well as a background I/O
thread that is responsible for turning these records into requests and
transmitting them to the cluster.
The send() method is asynchronous. When called it adds the record to a
buffer of pending record sends and immediately returns. This allows the
producer to batch together individual records for efficiency.
The 'acks' config controls the criteria under which requests are considered
complete. The "all" setting will result in blocking on the full commit of
the record, the slowest but most durable setting.
If the request fails, the producer can automatically retry, unless
'retries' is configured to 0. Enabling retries also opens up the
possibility of duplicates (see the documentation on message
delivery semantics for details:
http://kafka.apache.org/documentation.html#semantics
).
The producer maintains buffers of unsent records for each partition. These
buffers are of a size specified by the 'batch_size' config. Making this
larger can result in more batching, but requires more memory (since we will
generally have one of these buffers for each active partition).
By default a buffer is available to send immediately even if there is
additional unused space in the buffer. However if you want to reduce the
number of requests you can set 'linger_ms' to something greater than 0.
This will instruct the producer to wait up to that number of milliseconds
before sending a request in hope that more records will arrive to fill up
the same batch. This is analogous to Nagle's algorithm in TCP. Note that
records that arrive close together in time will generally batch together
even with linger_ms=0 so under heavy load batching will occur regardless of
the linger configuration; however setting this to something larger than 0
can lead to fewer, more efficient requests when not under maximal load at
the cost of a small amount of latency.
The buffer_memory controls the total amount of memory available to the
producer for buffering. If records are sent faster than they can be
transmitted to the server then this buffer space will be exhausted. When
the buffer space is exhausted additional send calls will block.
The key_serializer and value_serializer instruct how to turn the key and
value objects the user provides into bytes.
Keyword Arguments:
bootstrap_servers: 'host[:port]' string (or list of 'host[:port]'
strings) that the producer should contact to bootstrap initial
cluster metadata. This does not have to be the full node list.
It just needs to have at least one broker that will respond to a
Metadata API Request. Default port is 9092. If no servers are
specified, will default to localhost:9092.
client_id (str): a name for this client. This string is passed in
each request to servers and can be used to identify specific
server-side log entries that correspond to this client.
Default: 'kafka-python-producer-#' (appended with a unique number
per instance)
key_serializer (callable): used to convert user-supplied keys to bytes
If not None, called as f(key), should return bytes. Default: None.
value_serializer (callable): used to convert user-supplied message
values to bytes. If not None, called as f(value), should return
bytes. Default: None.
acks (0, 1, 'all'): The number of acknowledgments the producer requires
the leader to have received before considering a request complete.
This controls the durability of records that are sent. The
following settings are common:
0: Producer will not wait for any acknowledgment from the server
at all. The message will immediately be added to the socket
buffer and considered sent. No guarantee can be made that the
server has received the record in this case, and the retries
configuration will not take effect (as the client won't
generally know of any failures). The offset given back for each
record will always be set to -1.
1: The broker leader will write the record to its local log but
will respond without awaiting full acknowledgement from all
followers. In this case should the leader fail immediately
after acknowledging the record but before the followers have
replicated it then the record will be lost.
all: The broker leader will wait for the full set of in-sync
replicas to acknowledge the record. This guarantees that the
record will not be lost as long as at least one in-sync replica
remains alive. This is the strongest available guarantee.
If unset, defaults to acks=1.
compression_type (str): The compression type for all data generated by
the producer. Valid values are 'gzip', 'snappy', or None.
Compression is of full batches of data, so the efficacy of batching
will also impact the compression ratio (more batching means better
compression). Default: None.
retries (int): Setting a value greater than zero will cause the client
to resend any record whose send fails with a potentially transient
error. Note that this retry is no different than if the client
resent the record upon receiving the error. Allowing retries will
potentially change the ordering of records because if two records
are sent to a single partition, and the first fails and is retried
but the second succeeds, then the second record may appear first.
Default: 0.
batch_size (int): Requests sent to brokers will contain multiple
batches, one for each partition with data available to be sent.
A small batch size will make batching less common and may reduce
throughput (a batch size of zero will disable batching entirely).
Default: 16384
linger_ms (int): The producer groups together any records that arrive
in between request transmissions into a single batched request.
Normally this occurs only under load when records arrive faster
than they can be sent out. However in some circumstances the client
may want to reduce the number of requests even under moderate load.
This setting accomplishes this by adding a small amount of
artificial delay; that is, rather than immediately sending out a
record the producer will wait for up to the given delay to allow
other records to be sent so that the sends can be batched together.
This can be thought of as analogous to Nagle's algorithm in TCP.
This setting gives the upper bound on the delay for batching: once
we get batch_size worth of records for a partition it will be sent
immediately regardless of this setting, however if we have fewer
than this many bytes accumulated for this partition we will
'linger' for the specified time waiting for more records to show
up. This setting defaults to 0 (i.e. no delay). Setting linger_ms=5
would have the effect of reducing the number of requests sent but
would add up to 5ms of latency to records sent in the absense of
load. Default: 0.
partitioner (callable): Callable used to determine which partition
each message is assigned to. Called (after key serialization):
partitioner(key_bytes, all_partitions, available_partitions).
The default partitioner implementation hashes each non-None key
using the same murmur2 algorithm as the java client so that
messages with the same key are assigned to the same partition.
When a key is None, the message is delivered to a random partition
(filtered to partitions with available leaders only, if possible).
buffer_memory (int): The total bytes of memory the producer should use
to buffer records waiting to be sent to the server. If records are
sent faster than they can be delivered to the server the producer
will block up to max_block_ms, raising an exception on timeout.
In the current implementation, this setting is an approximation.
Default: 33554432 (32MB)
max_block_ms (int): Number of milliseconds to block during send()
when attempting to allocate additional memory before raising an
exception. Default: 60000.
max_request_size (int): The maximum size of a request. This is also
effectively a cap on the maximum record size. Note that the server
has its own cap on record size which may be different from this.
This setting will limit the number of record batches the producer
will send in a single request to avoid sending huge requests.
Default: 1048576.
metadata_max_age_ms (int): The period of time in milliseconds after
which we force a refresh of metadata even if we haven't seen any
partition leadership changes to proactively discover any new
brokers or partitions. Default: 300000
retry_backoff_ms (int): Milliseconds to backoff when retrying on
errors. Default: 100.
request_timeout_ms (int): Client request timeout in milliseconds.
Default: 30000.
receive_buffer_bytes (int): The size of the TCP receive buffer
(SO_RCVBUF) to use when reading data. Default: 32768
send_buffer_bytes (int): The size of the TCP send buffer
(SO_SNDBUF) to use when sending data. Default: 131072
reconnect_backoff_ms (int): The amount of time in milliseconds to
wait before attempting to reconnect to a given host.
Default: 50.
max_in_flight_requests_per_connection (int): Requests are pipelined
to kafka brokers up to this number of maximum requests per
broker connection. Default: 5.
api_version (str): specify which kafka API version to use.
If set to 'auto', will attempt to infer the broker version by
probing various APIs. Default: auto
Note:
Configuration parameters are described in more detail at
https://kafka.apache.org/090/configuration.html#producerconfigs
"""
_DEFAULT_CONFIG = {
'bootstrap_servers': 'localhost',
'client_id': None,
'key_serializer': None,
'value_serializer': None,
'acks': 1,
'compression_type': None,
'retries': 0,
'batch_size': 16384,
'linger_ms': 0,
'partitioner': DefaultPartitioner(),
'buffer_memory': 33554432,
'connections_max_idle_ms': 600000, # not implemented yet
'max_block_ms': 60000,
'max_request_size': 1048576,
'metadata_max_age_ms': 300000,
'retry_backoff_ms': 100,
'request_timeout_ms': 30000,
'receive_buffer_bytes': 32768,
'send_buffer_bytes': 131072,
'reconnect_backoff_ms': 50,
'max_in_flight_requests_per_connection': 5,
'api_version': 'auto',
}
def __init__(self, **configs):
log.debug("Starting the Kafka producer") # trace
self.config = copy.copy(self._DEFAULT_CONFIG)
for key in self.config:
if key in configs:
self.config[key] = configs.pop(key)
# Only check for extra config keys in top-level class
assert not configs, 'Unrecognized configs: %s' % configs
if self.config['client_id'] is None:
self.config['client_id'] = 'kafka-python-producer-%s' % \
PRODUCER_CLIENT_ID_SEQUENCE.increment()
if self.config['acks'] == 'all':
self.config['acks'] = -1
client = KafkaClient(**self.config)
# Check Broker Version if not set explicitly
if self.config['api_version'] == 'auto':
self.config['api_version'] = client.check_version()
assert self.config['api_version'] in ('0.9', '0.8.2', '0.8.1', '0.8.0')
# Convert api_version config to tuple for easy comparisons
self.config['api_version'] = tuple(
map(int, self.config['api_version'].split('.')))
if self.config['compression_type'] == 'lz4':
assert self.config['api_version'] >= (0, 8, 2), 'LZ4 Requires >= Kafka 0.8.2 Brokers'
self._accumulator = RecordAccumulator(**self.config)
self._metadata = client.cluster
self._metadata_lock = threading.Condition()
self._sender = Sender(client, self._metadata, self._metadata_lock,
self._accumulator, **self.config)
self._sender.daemon = True
self._sender.start()
self._closed = False
atexit.register(self.close, timeout=0)
log.debug("Kafka producer started")
def __del__(self):
self.close(timeout=0)
def close(self, timeout=None):
"""Close this producer."""
if self._closed:
log.info('Kafka producer closed')
return
if timeout is None:
timeout = 999999999
assert timeout >= 0
log.info("Closing the Kafka producer with %s secs timeout.", timeout)
#first_exception = AtomicReference() # this will keep track of the first encountered exception
invoked_from_callback = bool(threading.current_thread() is self._sender)
if timeout > 0:
if invoked_from_callback:
log.warning("Overriding close timeout %s secs to 0 in order to"
" prevent useless blocking due to self-join. This"
" means you have incorrectly invoked close with a"
" non-zero timeout from the producer call-back.",
timeout)
else:
# Try to close gracefully.
if self._sender is not None:
self._sender.initiate_close()
self._sender.join(timeout)
if self._sender is not None and self._sender.is_alive():
log.info("Proceeding to force close the producer since pending"
" requests could not be completed within timeout %s.",
timeout)
self._sender.force_close()
# Only join the sender thread when not calling from callback.
if not invoked_from_callback:
self._sender.join()
try:
self.config['key_serializer'].close()
except AttributeError:
pass
try:
self.config['value_serializer'].close()
except AttributeError:
pass
self._closed = True
log.debug("The Kafka producer has closed.")
def partitions_for(self, topic):
"""Returns set of all known partitions for the topic."""
max_wait = self.config['max_block_ms'] / 1000.0
return self._wait_on_metadata(topic, max_wait)
def send(self, topic, value=None, key=None, partition=None):
"""Publish a message to a topic.
Arguments:
topic (str): topic where the message will be published
value (optional): message value. Must be type bytes, or be
serializable to bytes via configured value_serializer. If value
is None, key is required and message acts as a 'delete'.
See kafka compaction documentation for more details:
http://kafka.apache.org/documentation.html#compaction
(compaction requires kafka >= 0.8.1)
partition (int, optional): optionally specify a partition. If not
set, the partition will be selected using the configured
'partitioner'.
key (optional): a key to associate with the message. Can be used to
determine which partition to send the message to. If partition
is None (and producer's partitioner config is left as default),
then messages with the same key will be delivered to the same
partition (but if key is None, partition is chosen randomly).
Must be type bytes, or be serializable to bytes via configured
key_serializer.
Returns:
FutureRecordMetadata: resolves to RecordMetadata
Raises:
KafkaTimeoutError: if unable to fetch topic metadata, or unable
to obtain memory buffer prior to configured max_block_ms
"""
assert value is not None or self.config['api_version'] >= (0, 8, 1), (
'Null messages require kafka >= 0.8.1')
assert not (value is None and key is None), 'Need at least one: key or value'
try:
# first make sure the metadata for the topic is
# available
self._wait_on_metadata(topic, self.config['max_block_ms'] / 1000.0)
key_bytes, value_bytes = self._serialize(topic, key, value)
partition = self._partition(topic, partition, key, value,
key_bytes, value_bytes)
message_size = MessageSet.HEADER_SIZE + Message.HEADER_SIZE
if key_bytes is not None:
message_size += len(key_bytes)
if value_bytes is not None:
message_size += len(value_bytes)
self._ensure_valid_record_size(message_size)
tp = TopicPartition(topic, partition)
log.debug("Sending (key=%s value=%s) to %s", key, value, tp)
result = self._accumulator.append(tp, key_bytes, value_bytes,
self.config['max_block_ms'])
future, batch_is_full, new_batch_created = result
if batch_is_full or new_batch_created:
log.debug("Waking up the sender since %s is either full or"
" getting a new batch", tp)
self._sender.wakeup()
return future
# handling exceptions and record the errors;
# for API exceptions return them in the future,
# for other exceptions raise directly
except Errors.KafkaTimeoutError:
raise
except AssertionError:
raise
except Exception as e:
log.debug("Exception occurred during message send: %s", e)
return FutureRecordMetadata(
FutureProduceResult(TopicPartition(topic, partition)),
-1).failure(e)
def flush(self):
"""
Invoking this method makes all buffered records immediately available
to send (even if linger_ms is greater than 0) and blocks on the
completion of the requests associated with these records. The
post-condition of flush() is that any previously sent record will have
completed (e.g. Future.is_done() == True). A request is considered
completed when either it is successfully acknowledged according to the
'acks' configuration for the producer, or it results in an error.
Other threads can continue sending messages while one thread is blocked
waiting for a flush call to complete; however, no guarantee is made
about the completion of messages sent after the flush call begins.
"""
log.debug("Flushing accumulated records in producer.") # trace
self._accumulator.begin_flush()
self._sender.wakeup()
self._accumulator.await_flush_completion()
def _ensure_valid_record_size(self, size):
"""Validate that the record size isn't too large."""
if size > self.config['max_request_size']:
raise Errors.MessageSizeTooLargeError(
"The message is %d bytes when serialized which is larger than"
" the maximum request size you have configured with the"
" max_request_size configuration" % size)
if size > self.config['buffer_memory']:
raise Errors.MessageSizeTooLargeError(
"The message is %d bytes when serialized which is larger than"
" the total memory buffer you have configured with the"
" buffer_memory configuration." % size)
def _wait_on_metadata(self, topic, max_wait):
"""
Wait for cluster metadata including partitions for the given topic to
be available.
Arguments:
topic (str): topic we want metadata for
max_wait (float): maximum time in secs for waiting on the metadata
Returns:
set: partition ids for the topic
Raises:
TimeoutException: if partitions for topic were not obtained before
specified max_wait timeout
"""
# add topic to metadata topic list if it is not there already.
self._sender.add_topic(topic)
partitions = self._metadata.partitions_for_topic(topic)
if partitions:
return partitions
event = threading.Event()
def event_set(*args):
event.set()
def request_update(self, event):
event.clear()
log.debug("Requesting metadata update for topic %s.", topic)
f = self._metadata.request_update()
f.add_both(event_set)
return f
begin = time.time()
elapsed = 0.0
future = request_update(self, event)
while elapsed < max_wait:
self._sender.wakeup()
event.wait(max_wait - elapsed)
if future.failed():
future = request_update(self, event)
elapsed = time.time() - begin
partitions = self._metadata.partitions_for_topic(topic)
if partitions:
return partitions
else:
raise Errors.KafkaTimeoutError(
"Failed to update metadata after %s secs.", max_wait)
def _serialize(self, topic, key, value):
# pylint: disable-msg=not-callable
if self.config['key_serializer']:
serialized_key = self.config['key_serializer'](key)
else:
serialized_key = key
if self.config['value_serializer']:
serialized_value = self.config['value_serializer'](value)
else:
serialized_value = value
return serialized_key, serialized_value
def _partition(self, topic, partition, key, value,
serialized_key, serialized_value):
if partition is not None:
assert partition >= 0
assert partition in self._metadata.partitions_for_topic(topic), 'Unrecognized partition'
return partition
all_partitions = list(self._metadata.partitions_for_topic(topic))
available = list(self._metadata.available_partitions_for_topic(topic))
return self.config['partitioner'](serialized_key,
all_partitions,
available)

View File

@@ -0,0 +1,500 @@
from __future__ import absolute_import
import collections
import copy
import logging
import threading
import time
import six
from ..common import TopicPartition
from ..protocol.message import Message, MessageSet
from .buffer import MessageSetBuffer, SimpleBufferPool
from .future import FutureRecordMetadata, FutureProduceResult
import kafka.common as Errors
log = logging.getLogger(__name__)
class AtomicInteger(object):
def __init__(self, val=0):
self._lock = threading.Lock()
self._val = val
def increment(self):
with self._lock:
self._val += 1
return self._val
def decrement(self):
with self._lock:
self._val -= 1
return self._val
def get(self):
return self._val
class RecordBatch(object):
def __init__(self, tp, records):
self.record_count = 0
#self.max_record_size = 0 # for metrics only
now = time.time()
#self.created = now # for metrics only
self.drained = None
self.attempts = 0
self.last_attempt = now
self.last_append = now
self.records = records
self.topic_partition = tp
self.produce_future = FutureProduceResult(tp)
self._retry = False
def try_append(self, key, value):
if not self.records.has_room_for(key, value):
return None
self.records.append(self.record_count, Message(value, key=key))
# self.max_record_size = max(self.max_record_size, Record.record_size(key, value)) # for metrics only
self.last_append = time.time()
future = FutureRecordMetadata(self.produce_future, self.record_count)
self.record_count += 1
return future
def done(self, base_offset=None, exception=None):
log.debug("Produced messages to topic-partition %s with base offset"
" %s and error %s.", self.topic_partition, base_offset,
exception) # trace
if exception is None:
self.produce_future.success(base_offset)
else:
self.produce_future.failure(exception)
def maybe_expire(self, request_timeout_ms, linger_ms):
since_append_ms = 1000 * (time.time() - self.last_append)
if ((self.records.is_full() and request_timeout_ms < since_append_ms)
or (request_timeout_ms < (since_append_ms + linger_ms))):
self.records.close()
self.done(-1, Errors.KafkaTimeoutError('Batch Expired'))
return True
return False
def in_retry(self):
return self._retry
def set_retry(self):
self._retry = True
def __str__(self):
return 'RecordBatch(topic_partition=%s, record_count=%d)' % (
self.topic_partition, self.record_count)
class RecordAccumulator(object):
"""
This class maintains a dequeue per TopicPartition that accumulates messages
into MessageSets to be sent to the server.
The accumulator attempts to bound memory use, and append calls will block
when that memory is exhausted.
Keyword Arguments:
batch_size (int): Requests sent to brokers will contain multiple
batches, one for each partition with data available to be sent.
A small batch size will make batching less common and may reduce
throughput (a batch size of zero will disable batching entirely).
Default: 16384
buffer_memory (int): The total bytes of memory the producer should use
to buffer records waiting to be sent to the server. If records are
sent faster than they can be delivered to the server the producer
will block up to max_block_ms, raising an exception on timeout.
In the current implementation, this setting is an approximation.
Default: 33554432 (32MB)
compression_type (str): The compression type for all data generated by
the producer. Valid values are 'gzip', 'snappy', or None.
Compression is of full batches of data, so the efficacy of batching
will also impact the compression ratio (more batching means better
compression). Default: None.
linger_ms (int): An artificial delay time to add before declaring a
messageset (that isn't full) ready for sending. This allows
time for more records to arrive. Setting a non-zero linger_ms
will trade off some latency for potentially better throughput
due to more batching (and hence fewer, larger requests).
Default: 0
retry_backoff_ms (int): An artificial delay time to retry the
produce request upon receiving an error. This avoids exhausting
all retries in a short period of time. Default: 100
"""
_DEFAULT_CONFIG = {
'buffer_memory': 33554432,
'batch_size': 16384,
'compression_type': None,
'linger_ms': 0,
'retry_backoff_ms': 100,
}
def __init__(self, **configs):
self.config = copy.copy(self._DEFAULT_CONFIG)
for key in self.config:
if key in configs:
self.config[key] = configs.pop(key)
self._closed = False
self._drain_index = 0
self._flushes_in_progress = AtomicInteger()
self._appends_in_progress = AtomicInteger()
self._batches = collections.defaultdict(collections.deque) # TopicPartition: [RecordBatch]
self._tp_locks = {None: threading.Lock()} # TopicPartition: Lock, plus a lock to add entries
self._free = SimpleBufferPool(self.config['buffer_memory'],
self.config['batch_size'])
self._incomplete = IncompleteRecordBatches()
def append(self, tp, key, value, max_time_to_block_ms):
"""Add a record to the accumulator, return the append result.
The append result will contain the future metadata, and flag for
whether the appended batch is full or a new batch is created
Arguments:
tp (TopicPartition): The topic/partition to which this record is
being sent
key (bytes): The key for the record
value (bytes): The value for the record
max_time_to_block_ms (int): The maximum time in milliseconds to
block for buffer memory to be available
Returns:
tuple: (future, batch_is_full, new_batch_created)
"""
assert isinstance(tp, TopicPartition), 'not TopicPartition'
assert not self._closed, 'RecordAccumulator is closed'
# We keep track of the number of appending thread to make sure we do not miss batches in
# abortIncompleteBatches().
self._appends_in_progress.increment()
try:
if tp not in self._tp_locks:
with self._tp_locks[None]:
if tp not in self._tp_locks:
self._tp_locks[tp] = threading.Lock()
with self._tp_locks[tp]:
# check if we have an in-progress batch
dq = self._batches[tp]
if dq:
last = dq[-1]
future = last.try_append(key, value)
if future is not None:
batch_is_full = len(dq) > 1 or last.records.is_full()
return future, batch_is_full, False
# we don't have an in-progress record batch try to allocate a new batch
message_size = MessageSet.HEADER_SIZE + Message.HEADER_SIZE
if key is not None:
message_size += len(key)
if value is not None:
message_size += len(value)
assert message_size <= self.config['buffer_memory'], 'message too big'
size = max(self.config['batch_size'], message_size)
log.debug("Allocating a new %d byte message buffer for %s", size, tp) # trace
buf = self._free.allocate(max_time_to_block_ms)
with self._tp_locks[tp]:
# Need to check if producer is closed again after grabbing the
# dequeue lock.
assert not self._closed, 'RecordAccumulator is closed'
if dq:
last = dq[-1]
future = last.try_append(key, value)
if future is not None:
# Somebody else found us a batch, return the one we
# waited for! Hopefully this doesn't happen often...
self._free.deallocate(buf)
batch_is_full = len(dq) > 1 or last.records.is_full()
return future, batch_is_full, False
records = MessageSetBuffer(buf, self.config['batch_size'],
self.config['compression_type'])
batch = RecordBatch(tp, records)
future = batch.try_append(key, value)
if not future:
raise Exception()
dq.append(batch)
self._incomplete.add(batch)
batch_is_full = len(dq) > 1 or batch.records.is_full()
return future, batch_is_full, True
finally:
self._appends_in_progress.decrement()
def abort_expired_batches(self, request_timeout_ms, cluster):
"""Abort the batches that have been sitting in RecordAccumulator for
more than the configured request_timeout due to metadata being
unavailable.
Arguments:
request_timeout_ms (int): milliseconds to timeout
cluster (ClusterMetadata): current metadata for kafka cluster
Returns:
list of RecordBatch that were expired
"""
expired_batches = []
count = 0
for tp, dq in six.iteritems(self._batches):
assert tp in self._tp_locks, 'TopicPartition not in locks dict'
with self._tp_locks[tp]:
# iterate over the batches and expire them if they have stayed
# in accumulator for more than request_timeout_ms
for batch in dq:
# check if the batch is expired
if batch.maybe_expire(request_timeout_ms,
self.config['linger_ms']):
expired_batches.append(batch)
count += 1
self.deallocate(batch)
elif not batch.in_retry():
break
if expired_batches:
log.debug("Expired %d batches in accumulator", count) # trace
return expired_batches
def reenqueue(self, batch):
"""Re-enqueue the given record batch in the accumulator to retry."""
now = time.time()
batch.attempts += 1
batch.last_attempt = now
batch.last_append = now
batch.set_retry()
assert batch.topic_partition in self._tp_locks, 'TopicPartition not in locks dict'
assert batch.topic_partition in self._batches, 'TopicPartition not in batches'
dq = self._batches[batch.topic_partition]
with self._tp_locks[batch.topic_partition]:
dq.appendleft(batch)
def ready(self, cluster):
"""
Get a list of nodes whose partitions are ready to be sent, and the
earliest time at which any non-sendable partition will be ready;
Also return the flag for whether there are any unknown leaders for the
accumulated partition batches.
A destination node is ready to send data if ANY one of its partition is
not backing off the send and ANY of the following are true:
* The record set is full
* The record set has sat in the accumulator for at least linger_ms
milliseconds
* The accumulator is out of memory and threads are blocking waiting
for data (in this case all partitions are immediately considered
ready).
* The accumulator has been closed
Arguments:
cluster (ClusterMetadata):
Returns:
tuple:
ready_nodes (set): node_ids that have ready batches
next_ready_check (float): secs until next ready after backoff
unknown_leaders_exist (bool): True if metadata refresh needed
"""
ready_nodes = set()
next_ready_check = 9999999.99
unknown_leaders_exist = False
now = time.time()
exhausted = bool(self._free.queued() > 0)
for tp, dq in six.iteritems(self._batches):
leader = cluster.leader_for_partition(tp)
if leader is None or leader == -1:
unknown_leaders_exist = True
continue
elif leader in ready_nodes:
continue
with self._tp_locks[tp]:
if not dq:
continue
batch = dq[0]
retry_backoff = self.config['retry_backoff_ms'] / 1000.0
linger = self.config['linger_ms'] / 1000.0
backing_off = bool(batch.attempts > 0 and
batch.last_attempt + retry_backoff > now)
waited_time = now - batch.last_attempt
time_to_wait = retry_backoff if backing_off else linger
time_left = max(time_to_wait - waited_time, 0)
full = bool(len(dq) > 1 or batch.records.is_full())
expired = bool(waited_time >= time_to_wait)
sendable = (full or expired or exhausted or self._closed or
self._flush_in_progress())
if sendable and not backing_off:
ready_nodes.add(leader)
else:
# Note that this results in a conservative estimate since
# an un-sendable partition may have a leader that will
# later be found to have sendable data. However, this is
# good enough since we'll just wake up and then sleep again
# for the remaining time.
next_ready_check = min(time_left, next_ready_check)
return ready_nodes, next_ready_check, unknown_leaders_exist
def has_unsent(self):
"""Return whether there is any unsent record in the accumulator."""
for tp, dq in six.iteritems(self._batches):
with self._tp_locks[tp]:
if len(dq):
return True
return False
def drain(self, cluster, nodes, max_size):
"""
Drain all the data for the given nodes and collate them into a list of
batches that will fit within the specified size on a per-node basis.
This method attempts to avoid choosing the same topic-node repeatedly.
Arguments:
cluster (ClusterMetadata): The current cluster metadata
nodes (list): list of node_ids to drain
max_size (int): maximum number of bytes to drain
Returns:
dict: {node_id: list of RecordBatch} with total size less than the
requested max_size.
"""
if not nodes:
return {}
now = time.time()
batches = {}
for node_id in nodes:
size = 0
partitions = list(cluster.partitions_for_broker(node_id))
ready = []
# to make starvation less likely this loop doesn't start at 0
self._drain_index %= len(partitions)
start = self._drain_index
while True:
tp = partitions[self._drain_index]
if tp in self._batches:
with self._tp_locks[tp]:
dq = self._batches[tp]
if dq:
first = dq[0]
backoff = (
bool(first.attempts > 0) and
bool(first.last_attempt +
self.config['retry_backoff_ms'] / 1000.0
> now)
)
# Only drain the batch if it is not during backoff
if not backoff:
if (size + first.records.size_in_bytes() > max_size
and len(ready) > 0):
# there is a rare case that a single batch
# size is larger than the request size due
# to compression; in this case we will
# still eventually send this batch in a
# single request
break
else:
batch = dq.popleft()
batch.records.close()
size += batch.records.size_in_bytes()
ready.append(batch)
batch.drained = now
self._drain_index += 1
self._drain_index %= len(partitions)
if start == self._drain_index:
break
batches[node_id] = ready
return batches
def deallocate(self, batch):
"""Deallocate the record batch."""
self._incomplete.remove(batch)
self._free.deallocate(batch.records.buffer())
def _flush_in_progress(self):
"""Are there any threads currently waiting on a flush?"""
return self._flushes_in_progress.get() > 0
def begin_flush(self):
"""
Initiate the flushing of data from the accumulator...this makes all
requests immediately ready
"""
self._flushes_in_progress.increment()
def await_flush_completion(self):
"""
Mark all partitions as ready to send and block until the send is complete
"""
for batch in self._incomplete.all():
batch.produce_future.await()
self._flushes_in_progress.decrement()
def abort_incomplete_batches(self):
"""
This function is only called when sender is closed forcefully. It will fail all the
incomplete batches and return.
"""
# We need to keep aborting the incomplete batch until no thread is trying to append to
# 1. Avoid losing batches.
# 2. Free up memory in case appending threads are blocked on buffer full.
# This is a tight loop but should be able to get through very quickly.
while True:
self._abort_batches()
if not self._appends_in_progress.get():
break
# After this point, no thread will append any messages because they will see the close
# flag set. We need to do the last abort after no thread was appending in case the there was a new
# batch appended by the last appending thread.
self._abort_batches()
self._batches.clear()
def _abort_batches(self):
"""Go through incomplete batches and abort them."""
error = Errors.IllegalStateError("Producer is closed forcefully.")
for batch in self._incomplete.all():
tp = batch.topic_partition
# Close the batch before aborting
with self._tp_locks[tp]:
batch.records.close()
batch.done(exception=error)
self.deallocate(batch)
def close(self):
"""Close this accumulator and force all the record buffers to be drained."""
self._closed = True
class IncompleteRecordBatches(object):
"""A threadsafe helper class to hold RecordBatches that haven't been ack'd yet"""
def __init__(self):
self._incomplete = set()
self._lock = threading.Lock()
def add(self, batch):
with self._lock:
return self._incomplete.add(batch)
def remove(self, batch):
with self._lock:
return self._incomplete.remove(batch)
def all(self):
with self._lock:
return list(self._incomplete)

272
kafka/producer/sender.py Normal file
View File

@@ -0,0 +1,272 @@
from __future__ import absolute_import
import collections
import copy
import logging
import threading
import time
import six
from ..common import TopicPartition
from ..version import __version__
from ..protocol.produce import ProduceRequest
import kafka.common as Errors
log = logging.getLogger(__name__)
class Sender(threading.Thread):
"""
The background thread that handles the sending of produce requests to the
Kafka cluster. This thread makes metadata requests to renew its view of the
cluster and then sends produce requests to the appropriate nodes.
"""
_DEFAULT_CONFIG = {
'max_request_size': 1048576,
'acks': 1,
'retries': 0,
'request_timeout_ms': 30000,
'client_id': 'kafka-python-' + __version__,
}
def __init__(self, client, metadata, lock, accumulator, **configs):
super(Sender, self).__init__()
self.config = copy.copy(self._DEFAULT_CONFIG)
for key in self.config:
if key in configs:
self.config[key] = configs.pop(key)
self.name = self.config['client_id'] + '-network-thread'
self._client = client
self._accumulator = accumulator
self._metadata = client.cluster
self._lock = lock
self._running = True
self._force_close = False
self._topics_to_add = []
def run(self):
"""The main run loop for the sender thread."""
log.debug("Starting Kafka producer I/O thread.")
# main loop, runs until close is called
while self._running:
try:
self.run_once()
except Exception:
log.exception("Uncaught error in kafka producer I/O thread")
log.debug("Beginning shutdown of Kafka producer I/O thread, sending"
" remaining records.")
# okay we stopped accepting requests but there may still be
# requests in the accumulator or waiting for acknowledgment,
# wait until these are completed.
while (not self._force_close
and (self._accumulator.has_unsent()
or self._client.in_flight_request_count() > 0)):
try:
self.run_once()
except Exception:
log.exception("Uncaught error in kafka producer I/O thread")
if self._force_close:
# We need to fail all the incomplete batches and wake up the
# threads waiting on the futures.
self._accumulator.abort_incomplete_batches()
try:
self._client.close()
except Exception:
log.exception("Failed to close network client")
log.debug("Shutdown of Kafka producer I/O thread has completed.")
def run_once(self):
"""Run a single iteration of sending."""
while self._topics_to_add:
self._client.add_topic(self._topics_to_add.pop())
# get the list of partitions with data ready to send
result = self._accumulator.ready(self._metadata)
ready_nodes, next_ready_check_delay, unknown_leaders_exist = result
# if there are any partitions whose leaders are not known yet, force
# metadata update
if unknown_leaders_exist:
with self._lock:
self._metadata.request_update()
# remove any nodes we aren't ready to send to
not_ready_timeout = 999999999
for node in list(ready_nodes):
if not self._client.ready(node):
ready_nodes.remove(node)
not_ready_timeout = min(not_ready_timeout,
self._client.connection_delay(node))
# create produce requests
batches_by_node = self._accumulator.drain(
self._metadata, ready_nodes, self.config['max_request_size'])
expired_batches = self._accumulator.abort_expired_batches(
self.config['request_timeout_ms'], self._metadata)
requests = self._create_produce_requests(batches_by_node)
# If we have any nodes that are ready to send + have sendable data,
# poll with 0 timeout so this can immediately loop and try sending more
# data. Otherwise, the timeout is determined by nodes that have
# partitions with data that isn't yet sendable (e.g. lingering, backing
# off). Note that this specifically does not include nodes with
# sendable data that aren't ready to send since they would cause busy
# looping.
poll_timeout_ms = min(next_ready_check_delay * 1000, not_ready_timeout)
if ready_nodes:
log.debug("Nodes with data ready to send: %s", ready_nodes) # trace
log.debug("Created %d produce requests: %s", len(requests), requests) # trace
poll_timeout_ms = 0
with self._lock:
for node_id, request in six.iteritems(requests):
batches = batches_by_node[node_id]
log.debug('Sending Produce Request: %r', request)
(self._client.send(node_id, request)
.add_callback(
self._handle_produce_response, batches)
.add_errback(
self._failed_produce, batches, node_id))
# if some partitions are already ready to be sent, the select time
# would be 0; otherwise if some partition already has some data
# accumulated but not ready yet, the select time will be the time
# difference between now and its linger expiry time; otherwise the
# select time will be the time difference between now and the
# metadata expiry time
self._client.poll(poll_timeout_ms, sleep=True)
def initiate_close(self):
"""Start closing the sender (won't complete until all data is sent)."""
self._running = False
self._accumulator.close()
self.wakeup()
def force_close(self):
"""Closes the sender without sending out any pending messages."""
self._force_close = True
self.initiate_close()
def add_topic(self, topic):
self._topics_to_add.append(topic)
self.wakeup()
def _failed_produce(self, batches, node_id, error):
log.debug("Error sending produce request to node %d: %s", node_id, error) # trace
for batch in batches:
self._complete_batch(batch, error, -1)
def _handle_produce_response(self, batches, response):
"""Handle a produce response."""
# if we have a response, parse it
log.debug('Parsing produce response: %r', response)
if response:
batches_by_partition = dict([(batch.topic_partition, batch)
for batch in batches])
for topic, partitions in response.topics:
for partition, error_code, offset in partitions:
tp = TopicPartition(topic, partition)
error = Errors.for_code(error_code)
batch = batches_by_partition[tp]
self._complete_batch(batch, error, offset)
else:
# this is the acks = 0 case, just complete all requests
for batch in batches:
self._complete_batch(batch, None, -1)
def _complete_batch(self, batch, error, base_offset):
"""Complete or retry the given batch of records.
Arguments:
batch (RecordBatch): The record batch
error (Exception): The error (or None if none)
base_offset (int): The base offset assigned to the records if successful
"""
# Standardize no-error to None
if error is Errors.NoError:
error = None
if error is not None and self._can_retry(batch, error):
# retry
log.warning("Got error produce response on topic-partition %s,"
" retrying (%d attempts left). Error: %s",
batch.topic_partition,
self.config['retries'] - batch.attempts - 1,
error)
self._accumulator.reenqueue(batch)
else:
if error is Errors.TopicAuthorizationFailedError:
error = error(batch.topic_partition.topic)
# tell the user the result of their request
batch.done(base_offset, error)
self._accumulator.deallocate(batch)
if getattr(error, 'invalid_metadata', False):
self._metadata.request_update()
def _can_retry(self, batch, error):
"""
We can retry a send if the error is transient and the number of
attempts taken is fewer than the maximum allowed
"""
return (batch.attempts < self.config['retries']
and getattr(error, 'retriable', False))
def _create_produce_requests(self, collated):
"""
Transfer the record batches into a list of produce requests on a
per-node basis.
Arguments:
collated: {node_id: [RecordBatch]}
Returns:
dict: {node_id: ProduceRequest}
"""
requests = {}
for node_id, batches in six.iteritems(collated):
requests[node_id] = self._produce_request(
node_id, self.config['acks'],
self.config['request_timeout_ms'], batches)
return requests
def _produce_request(self, node_id, acks, timeout, batches):
"""Create a produce request from the given record batches.
Returns:
ProduceRequest
"""
produce_records_by_partition = collections.defaultdict(dict)
for batch in batches:
topic = batch.topic_partition.topic
partition = batch.topic_partition.partition
# TODO: bytearray / memoryview
buf = batch.records.buffer()
produce_records_by_partition[topic][partition] = buf
return ProduceRequest(
required_acks=acks,
timeout=timeout,
topics=[(topic, list(partition_info.items()))
for topic, partition_info
in six.iteritems(produce_records_by_partition)]
)
def wakeup(self):
"""Wake up the selector associated with this send thread."""
self._client.wakeup()

View File

@@ -20,6 +20,7 @@ class Message(Struct):
CODEC_MASK = 0x03
CODEC_GZIP = 0x01
CODEC_SNAPPY = 0x02
HEADER_SIZE = 14 # crc(4), magic(1), attributes(1), key+value size(4*2)
def __init__(self, value, key=None, magic=0, attributes=0, crc=0):
assert value is None or isinstance(value, bytes), 'value must be bytes'
@@ -83,9 +84,17 @@ class MessageSet(AbstractType):
('message_size', Int32),
('message', Message.SCHEMA)
)
HEADER_SIZE = 12 # offset + message_size
@classmethod
def encode(cls, items, size=True, recalc_message_size=True):
# RecordAccumulator encodes messagesets internally
if isinstance(items, io.BytesIO):
size = Int32.decode(items)
# rewind and return all the bytes
items.seek(-4, 1)
return items.read(size + 4)
encoded_values = []
for (offset, message_size, message) in items:
if isinstance(message, Message):
@@ -141,4 +150,9 @@ class MessageSet(AbstractType):
@classmethod
def repr(cls, messages):
if isinstance(messages, io.BytesIO):
offset = messages.tell()
decoded = cls.decode(messages)
messages.seek(offset)
messages = decoded
return '[' + ', '.join([cls.ITEM.repr(m) for m in messages]) + ']'

33
test/conftest.py Normal file
View File

@@ -0,0 +1,33 @@
import os
import pytest
from test.fixtures import KafkaFixture, ZookeeperFixture
@pytest.fixture(scope="module")
def version():
if 'KAFKA_VERSION' not in os.environ:
return ()
return tuple(map(int, os.environ['KAFKA_VERSION'].split('.')))
@pytest.fixture(scope="module")
def zookeeper(version, request):
assert version
zk = ZookeeperFixture.instance()
def fin():
zk.close()
request.addfinalizer(fin)
return zk
@pytest.fixture(scope="module")
def kafka_broker(version, zookeeper, request):
assert version
k = KafkaFixture.instance(0, zookeeper.host, zookeeper.port,
partitions=4)
def fin():
k.close()
request.addfinalizer(fin)
return k

View File

@@ -5,10 +5,11 @@ import shutil
import subprocess
import tempfile
import time
from six.moves import urllib
import uuid
from six.moves import urllib
from six.moves.urllib.parse import urlparse # pylint: disable=E0611,F0401
from test.service import ExternalService, SpawnedService
from test.testutil import get_open_port

View File

@@ -12,38 +12,10 @@ from kafka.common import TopicPartition
from kafka.conn import BrokerConnection, ConnectionStates
from kafka.consumer.group import KafkaConsumer
from test.fixtures import KafkaFixture, ZookeeperFixture
from test.conftest import version
from test.testutil import random_string
@pytest.fixture(scope="module")
def version():
if 'KAFKA_VERSION' not in os.environ:
return ()
return tuple(map(int, os.environ['KAFKA_VERSION'].split('.')))
@pytest.fixture(scope="module")
def zookeeper(version, request):
assert version
zk = ZookeeperFixture.instance()
def fin():
zk.close()
request.addfinalizer(fin)
return zk
@pytest.fixture(scope="module")
def kafka_broker(version, zookeeper, request):
assert version
k = KafkaFixture.instance(0, zookeeper.host, zookeeper.port,
partitions=4)
def fin():
k.close()
request.addfinalizer(fin)
return k
@pytest.fixture
def simple_client(kafka_broker):
connect_str = 'localhost:' + str(kafka_broker.port)

View File

@@ -1,23 +1,43 @@
import pytest
import six
from . import unittest
from kafka.partitioner import (Murmur2Partitioner)
from kafka.partitioner import Murmur2Partitioner
from kafka.partitioner.default import DefaultPartitioner
class TestMurmurPartitioner(unittest.TestCase):
def test_hash_bytes(self):
p = Murmur2Partitioner(range(1000))
self.assertEqual(p.partition(bytearray(b'test')), p.partition(b'test'))
def test_hash_encoding(self):
p = Murmur2Partitioner(range(1000))
self.assertEqual(p.partition('test'), p.partition(u'test'))
def test_default_partitioner():
partitioner = DefaultPartitioner()
all_partitions = list(range(100))
available = all_partitions
# partitioner should return the same partition for the same key
p1 = partitioner(b'foo', all_partitions, available)
p2 = partitioner(b'foo', all_partitions, available)
assert p1 == p2
assert p1 in all_partitions
def test_murmur2_java_compatibility(self):
p = Murmur2Partitioner(range(1000))
# compare with output from Kafka's org.apache.kafka.clients.producer.Partitioner
self.assertEqual(681, p.partition(b''))
self.assertEqual(524, p.partition(b'a'))
self.assertEqual(434, p.partition(b'ab'))
self.assertEqual(107, p.partition(b'abc'))
self.assertEqual(566, p.partition(b'123456789'))
self.assertEqual(742, p.partition(b'\x00 '))
# when key is None, choose one of available partitions
assert partitioner(None, all_partitions, [123]) == 123
# with fallback to all_partitions
assert partitioner(None, all_partitions, []) in all_partitions
def test_hash_bytes():
p = Murmur2Partitioner(range(1000))
assert p.partition(bytearray(b'test')) == p.partition(b'test')
def test_hash_encoding():
p = Murmur2Partitioner(range(1000))
assert p.partition('test') == p.partition(u'test')
def test_murmur2_java_compatibility():
p = Murmur2Partitioner(range(1000))
# compare with output from Kafka's org.apache.kafka.clients.producer.Partitioner
assert p.partition(b'') == 681
assert p.partition(b'a') == 524
assert p.partition(b'ab') == 434
assert p.partition(b'abc') == 107
assert p.partition(b'123456789') == 566
assert p.partition(b'\x00 ') == 742

View File

@@ -1,257 +1,34 @@
# -*- coding: utf-8 -*-
import pytest
import collections
import logging
import threading
import time
from mock import MagicMock, patch
from . import unittest
from kafka import SimpleClient, SimpleProducer, KeyedProducer
from kafka.common import (
AsyncProducerQueueFull, FailedPayloadsError, NotLeaderForPartitionError,
ProduceResponsePayload, RetryOptions, TopicPartition
)
from kafka.producer.base import Producer, _send_upstream
from kafka.protocol import CODEC_NONE
from six.moves import queue, xrange
from kafka import KafkaConsumer, KafkaProducer
from test.conftest import version
from test.testutil import random_string
class TestKafkaProducer(unittest.TestCase):
def test_producer_message_types(self):
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_end_to_end(kafka_broker):
connect_str = 'localhost:' + str(kafka_broker.port)
producer = KafkaProducer(bootstrap_servers=connect_str,
max_block_ms=10000,
value_serializer=str.encode)
consumer = KafkaConsumer(bootstrap_servers=connect_str,
consumer_timeout_ms=10000,
auto_offset_reset='earliest',
value_deserializer=bytes.decode)
producer = Producer(MagicMock())
topic = b"test-topic"
partition = 0
topic = random_string(5)
bad_data_types = (u'你怎么样?', 12, ['a', 'list'],
('a', 'tuple'), {'a': 'dict'}, None,)
for m in bad_data_types:
with self.assertRaises(TypeError):
logging.debug("attempting to send message of type %s", type(m))
producer.send_messages(topic, partition, m)
for i in range(1000):
producer.send(topic, 'msg %d' % i)
producer.flush()
producer.close()
good_data_types = (b'a string!',)
for m in good_data_types:
# This should not raise an exception
producer.send_messages(topic, partition, m)
consumer.subscribe([topic])
msgs = set()
for i in range(1000):
try:
msgs.add(next(consumer).value)
except StopIteration:
break
def test_keyedproducer_message_types(self):
client = MagicMock()
client.get_partition_ids_for_topic.return_value = [0, 1]
producer = KeyedProducer(client)
topic = b"test-topic"
key = b"testkey"
bad_data_types = (u'你怎么样?', 12, ['a', 'list'],
('a', 'tuple'), {'a': 'dict'},)
for m in bad_data_types:
with self.assertRaises(TypeError):
logging.debug("attempting to send message of type %s", type(m))
producer.send_messages(topic, key, m)
good_data_types = (b'a string!', None,)
for m in good_data_types:
# This should not raise an exception
producer.send_messages(topic, key, m)
def test_topic_message_types(self):
client = MagicMock()
def partitions(topic):
return [0, 1]
client.get_partition_ids_for_topic = partitions
producer = SimpleProducer(client, random_start=False)
topic = b"test-topic"
producer.send_messages(topic, b'hi')
assert client.send_produce_request.called
@patch('kafka.producer.base._send_upstream')
def test_producer_async_queue_overfilled(self, mock):
queue_size = 2
producer = Producer(MagicMock(), async=True,
async_queue_maxsize=queue_size)
topic = b'test-topic'
partition = 0
message = b'test-message'
with self.assertRaises(AsyncProducerQueueFull):
message_list = [message] * (queue_size + 1)
producer.send_messages(topic, partition, *message_list)
self.assertEqual(producer.queue.qsize(), queue_size)
for _ in xrange(producer.queue.qsize()):
producer.queue.get()
def test_producer_sync_fail_on_error(self):
error = FailedPayloadsError('failure')
with patch.object(SimpleClient, 'load_metadata_for_topics'):
with patch.object(SimpleClient, 'ensure_topic_exists'):
with patch.object(SimpleClient, 'get_partition_ids_for_topic', return_value=[0, 1]):
with patch.object(SimpleClient, '_send_broker_aware_request', return_value = [error]):
client = SimpleClient(MagicMock())
producer = SimpleProducer(client, async=False, sync_fail_on_error=False)
# This should not raise
(response,) = producer.send_messages('foobar', b'test message')
self.assertEqual(response, error)
producer = SimpleProducer(client, async=False, sync_fail_on_error=True)
with self.assertRaises(FailedPayloadsError):
producer.send_messages('foobar', b'test message')
def test_cleanup_is_not_called_on_stopped_producer(self):
producer = Producer(MagicMock(), async=True)
producer.stopped = True
with patch.object(producer, 'stop') as mocked_stop:
producer._cleanup_func(producer)
self.assertEqual(mocked_stop.call_count, 0)
def test_cleanup_is_called_on_running_producer(self):
producer = Producer(MagicMock(), async=True)
producer.stopped = False
with patch.object(producer, 'stop') as mocked_stop:
producer._cleanup_func(producer)
self.assertEqual(mocked_stop.call_count, 1)
class TestKafkaProducerSendUpstream(unittest.TestCase):
def setUp(self):
self.client = MagicMock()
self.queue = queue.Queue()
def _run_process(self, retries_limit=3, sleep_timeout=1):
# run _send_upstream process with the queue
stop_event = threading.Event()
retry_options = RetryOptions(limit=retries_limit,
backoff_ms=50,
retry_on_timeouts=False)
self.thread = threading.Thread(
target=_send_upstream,
args=(self.queue, self.client, CODEC_NONE,
0.3, # batch time (seconds)
3, # batch length
Producer.ACK_AFTER_LOCAL_WRITE,
Producer.DEFAULT_ACK_TIMEOUT,
retry_options,
stop_event))
self.thread.daemon = True
self.thread.start()
time.sleep(sleep_timeout)
stop_event.set()
def test_wo_retries(self):
# lets create a queue and add 10 messages for 1 partition
for i in range(10):
self.queue.put((TopicPartition("test", 0), "msg %i", "key %i"))
self._run_process()
# the queue should be void at the end of the test
self.assertEqual(self.queue.empty(), True)
# there should be 4 non-void cals:
# 3 batches of 3 msgs each + 1 batch of 1 message
self.assertEqual(self.client.send_produce_request.call_count, 4)
def test_first_send_failed(self):
# lets create a queue and add 10 messages for 10 different partitions
# to show how retries should work ideally
for i in range(10):
self.queue.put((TopicPartition("test", i), "msg %i", "key %i"))
# Mock offsets counter for closure
offsets = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
self.client.is_first_time = True
def send_side_effect(reqs, *args, **kwargs):
if self.client.is_first_time:
self.client.is_first_time = False
return [FailedPayloadsError(req) for req in reqs]
responses = []
for req in reqs:
offset = offsets[req.topic][req.partition]
offsets[req.topic][req.partition] += len(req.messages)
responses.append(
ProduceResponsePayload(req.topic, req.partition, 0, offset)
)
return responses
self.client.send_produce_request.side_effect = send_side_effect
self._run_process(2)
# the queue should be void at the end of the test
self.assertEqual(self.queue.empty(), True)
# there should be 5 non-void calls: 1st failed batch of 3 msgs
# plus 3 batches of 3 msgs each + 1 batch of 1 message
self.assertEqual(self.client.send_produce_request.call_count, 5)
def test_with_limited_retries(self):
# lets create a queue and add 10 messages for 10 different partitions
# to show how retries should work ideally
for i in range(10):
self.queue.put((TopicPartition("test", i), "msg %i" % i, "key %i" % i))
def send_side_effect(reqs, *args, **kwargs):
return [FailedPayloadsError(req) for req in reqs]
self.client.send_produce_request.side_effect = send_side_effect
self._run_process(3, 3)
# the queue should be void at the end of the test
self.assertEqual(self.queue.empty(), True)
# there should be 16 non-void calls:
# 3 initial batches of 3 msgs each + 1 initial batch of 1 msg +
# 3 retries of the batches above = (1 + 3 retries) * 4 batches = 16
self.assertEqual(self.client.send_produce_request.call_count, 16)
def test_async_producer_not_leader(self):
for i in range(10):
self.queue.put((TopicPartition("test", i), "msg %i", "key %i"))
# Mock offsets counter for closure
offsets = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
self.client.is_first_time = True
def send_side_effect(reqs, *args, **kwargs):
if self.client.is_first_time:
self.client.is_first_time = False
return [ProduceResponsePayload(req.topic, req.partition,
NotLeaderForPartitionError.errno, -1)
for req in reqs]
responses = []
for req in reqs:
offset = offsets[req.topic][req.partition]
offsets[req.topic][req.partition] += len(req.messages)
responses.append(
ProduceResponsePayload(req.topic, req.partition, 0, offset)
)
return responses
self.client.send_produce_request.side_effect = send_side_effect
self._run_process(2)
# the queue should be void at the end of the test
self.assertEqual(self.queue.empty(), True)
# there should be 5 non-void calls: 1st failed batch of 3 msgs
# + 3 batches of 3 msgs each + 1 batch of 1 msg = 1 + 3 + 1 = 5
self.assertEqual(self.client.send_produce_request.call_count, 5)
def tearDown(self):
for _ in xrange(self.queue.qsize()):
self.queue.get()
assert msgs == set(['msg %d' % i for i in range(1000)])

View File

@@ -0,0 +1,257 @@
# -*- coding: utf-8 -*-
import collections
import logging
import threading
import time
from mock import MagicMock, patch
from . import unittest
from kafka import SimpleClient, SimpleProducer, KeyedProducer
from kafka.common import (
AsyncProducerQueueFull, FailedPayloadsError, NotLeaderForPartitionError,
ProduceResponsePayload, RetryOptions, TopicPartition
)
from kafka.producer.base import Producer, _send_upstream
from kafka.protocol import CODEC_NONE
from six.moves import queue, xrange
class TestKafkaProducer(unittest.TestCase):
def test_producer_message_types(self):
producer = Producer(MagicMock())
topic = b"test-topic"
partition = 0
bad_data_types = (u'你怎么样?', 12, ['a', 'list'],
('a', 'tuple'), {'a': 'dict'}, None,)
for m in bad_data_types:
with self.assertRaises(TypeError):
logging.debug("attempting to send message of type %s", type(m))
producer.send_messages(topic, partition, m)
good_data_types = (b'a string!',)
for m in good_data_types:
# This should not raise an exception
producer.send_messages(topic, partition, m)
def test_keyedproducer_message_types(self):
client = MagicMock()
client.get_partition_ids_for_topic.return_value = [0, 1]
producer = KeyedProducer(client)
topic = b"test-topic"
key = b"testkey"
bad_data_types = (u'你怎么样?', 12, ['a', 'list'],
('a', 'tuple'), {'a': 'dict'},)
for m in bad_data_types:
with self.assertRaises(TypeError):
logging.debug("attempting to send message of type %s", type(m))
producer.send_messages(topic, key, m)
good_data_types = (b'a string!', None,)
for m in good_data_types:
# This should not raise an exception
producer.send_messages(topic, key, m)
def test_topic_message_types(self):
client = MagicMock()
def partitions(topic):
return [0, 1]
client.get_partition_ids_for_topic = partitions
producer = SimpleProducer(client, random_start=False)
topic = b"test-topic"
producer.send_messages(topic, b'hi')
assert client.send_produce_request.called
@patch('kafka.producer.base._send_upstream')
def test_producer_async_queue_overfilled(self, mock):
queue_size = 2
producer = Producer(MagicMock(), async=True,
async_queue_maxsize=queue_size)
topic = b'test-topic'
partition = 0
message = b'test-message'
with self.assertRaises(AsyncProducerQueueFull):
message_list = [message] * (queue_size + 1)
producer.send_messages(topic, partition, *message_list)
self.assertEqual(producer.queue.qsize(), queue_size)
for _ in xrange(producer.queue.qsize()):
producer.queue.get()
def test_producer_sync_fail_on_error(self):
error = FailedPayloadsError('failure')
with patch.object(SimpleClient, 'load_metadata_for_topics'):
with patch.object(SimpleClient, 'ensure_topic_exists'):
with patch.object(SimpleClient, 'get_partition_ids_for_topic', return_value=[0, 1]):
with patch.object(SimpleClient, '_send_broker_aware_request', return_value = [error]):
client = SimpleClient(MagicMock())
producer = SimpleProducer(client, async=False, sync_fail_on_error=False)
# This should not raise
(response,) = producer.send_messages('foobar', b'test message')
self.assertEqual(response, error)
producer = SimpleProducer(client, async=False, sync_fail_on_error=True)
with self.assertRaises(FailedPayloadsError):
producer.send_messages('foobar', b'test message')
def test_cleanup_is_not_called_on_stopped_producer(self):
producer = Producer(MagicMock(), async=True)
producer.stopped = True
with patch.object(producer, 'stop') as mocked_stop:
producer._cleanup_func(producer)
self.assertEqual(mocked_stop.call_count, 0)
def test_cleanup_is_called_on_running_producer(self):
producer = Producer(MagicMock(), async=True)
producer.stopped = False
with patch.object(producer, 'stop') as mocked_stop:
producer._cleanup_func(producer)
self.assertEqual(mocked_stop.call_count, 1)
class TestKafkaProducerSendUpstream(unittest.TestCase):
def setUp(self):
self.client = MagicMock()
self.queue = queue.Queue()
def _run_process(self, retries_limit=3, sleep_timeout=1):
# run _send_upstream process with the queue
stop_event = threading.Event()
retry_options = RetryOptions(limit=retries_limit,
backoff_ms=50,
retry_on_timeouts=False)
self.thread = threading.Thread(
target=_send_upstream,
args=(self.queue, self.client, CODEC_NONE,
0.3, # batch time (seconds)
3, # batch length
Producer.ACK_AFTER_LOCAL_WRITE,
Producer.DEFAULT_ACK_TIMEOUT,
retry_options,
stop_event))
self.thread.daemon = True
self.thread.start()
time.sleep(sleep_timeout)
stop_event.set()
def test_wo_retries(self):
# lets create a queue and add 10 messages for 1 partition
for i in range(10):
self.queue.put((TopicPartition("test", 0), "msg %i", "key %i"))
self._run_process()
# the queue should be void at the end of the test
self.assertEqual(self.queue.empty(), True)
# there should be 4 non-void cals:
# 3 batches of 3 msgs each + 1 batch of 1 message
self.assertEqual(self.client.send_produce_request.call_count, 4)
def test_first_send_failed(self):
# lets create a queue and add 10 messages for 10 different partitions
# to show how retries should work ideally
for i in range(10):
self.queue.put((TopicPartition("test", i), "msg %i", "key %i"))
# Mock offsets counter for closure
offsets = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
self.client.is_first_time = True
def send_side_effect(reqs, *args, **kwargs):
if self.client.is_first_time:
self.client.is_first_time = False
return [FailedPayloadsError(req) for req in reqs]
responses = []
for req in reqs:
offset = offsets[req.topic][req.partition]
offsets[req.topic][req.partition] += len(req.messages)
responses.append(
ProduceResponsePayload(req.topic, req.partition, 0, offset)
)
return responses
self.client.send_produce_request.side_effect = send_side_effect
self._run_process(2)
# the queue should be void at the end of the test
self.assertEqual(self.queue.empty(), True)
# there should be 5 non-void calls: 1st failed batch of 3 msgs
# plus 3 batches of 3 msgs each + 1 batch of 1 message
self.assertEqual(self.client.send_produce_request.call_count, 5)
def test_with_limited_retries(self):
# lets create a queue and add 10 messages for 10 different partitions
# to show how retries should work ideally
for i in range(10):
self.queue.put((TopicPartition("test", i), "msg %i" % i, "key %i" % i))
def send_side_effect(reqs, *args, **kwargs):
return [FailedPayloadsError(req) for req in reqs]
self.client.send_produce_request.side_effect = send_side_effect
self._run_process(3, 3)
# the queue should be void at the end of the test
self.assertEqual(self.queue.empty(), True)
# there should be 16 non-void calls:
# 3 initial batches of 3 msgs each + 1 initial batch of 1 msg +
# 3 retries of the batches above = (1 + 3 retries) * 4 batches = 16
self.assertEqual(self.client.send_produce_request.call_count, 16)
def test_async_producer_not_leader(self):
for i in range(10):
self.queue.put((TopicPartition("test", i), "msg %i", "key %i"))
# Mock offsets counter for closure
offsets = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
self.client.is_first_time = True
def send_side_effect(reqs, *args, **kwargs):
if self.client.is_first_time:
self.client.is_first_time = False
return [ProduceResponsePayload(req.topic, req.partition,
NotLeaderForPartitionError.errno, -1)
for req in reqs]
responses = []
for req in reqs:
offset = offsets[req.topic][req.partition]
offsets[req.topic][req.partition] += len(req.messages)
responses.append(
ProduceResponsePayload(req.topic, req.partition, 0, offset)
)
return responses
self.client.send_produce_request.side_effect = send_side_effect
self._run_process(2)
# the queue should be void at the end of the test
self.assertEqual(self.queue.empty(), True)
# there should be 5 non-void calls: 1st failed batch of 3 msgs
# + 3 batches of 3 msgs each + 1 batch of 1 msg = 1 + 3 + 1 = 5
self.assertEqual(self.client.send_produce_request.call_count, 5)
def tearDown(self):
for _ in xrange(self.queue.qsize()):
self.queue.get()