Merge branch 'vshlapakov-feature-async-threading'
PR 330: Threading for async batching

Conflicts:
    kafka/producer/base.py
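The merge replaces the async producer's multiprocessing.Process with a daemon threading.Thread plus a threading.Event used to signal shutdown (a thread cannot be terminated from outside the way a process can). A minimal sketch of that worker pattern, using hypothetical names rather than the kafka-python internals:

    import queue
    import threading
    import time

    def worker(q, stop_event):
        # Miniature batching loop: run until the owner sets the event.
        while not stop_event.is_set():
            try:
                item = q.get(timeout=0.5)
            except queue.Empty:
                continue
            print("would send", item)

    q = queue.Queue()
    stop_event = threading.Event()
    t = threading.Thread(target=worker, args=(q, stop_event))
    t.daemon = True        # like the producer thread: dies with the main thread
    t.start()

    q.put("message")
    time.sleep(0.1)        # give the worker a moment to pick the message up
    stop_event.set()       # cooperative shutdown; there is no Thread.terminate()
    t.join()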
@@ -62,6 +62,9 @@ class KafkaConnection(local):
 
         self.reinit()
 
+    def __getnewargs__(self):
+        return (self.host, self.port, self.timeout)
+
     def __repr__(self):
         return "<KafkaConnection host=%s port=%d>" % (self.host, self.port)
 
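KafkaConnection subclasses threading.local, and its copy() is built on copy.deepcopy(); deepcopy recreates the object through __new__, and __getnewargs__ supplies the constructor arguments, which threading.local also replays through __init__ the first time another thread touches the copy. A standalone sketch of that mechanism, with an illustrative class rather than the real KafkaConnection:

    import copy
    import threading

    class LocalConn(threading.local):
        """Illustrative stand-in for KafkaConnection, not the real class."""
        def __init__(self, host, port, timeout=120):
            self.host = host
            self.port = port
            self.timeout = timeout

        def __getnewargs__(self):
            # deepcopy/pickle rebuild the object as cls.__new__(cls, *args);
            # threading.local keeps those args and re-runs __init__ with them
            # the first time a new thread touches the instance.
            return (self.host, self.port, self.timeout)

    original = LocalConn('kafka', 9092)
    clone = copy.deepcopy(original)

    seen = []
    t = threading.Thread(target=lambda: seen.append((clone.host, clone.port)))
    t.start()
    t.join()
    assert seen == [('kafka', 9092)]   # the copy re-initialises itself per thread

The test_copy_thread case added at the bottom of this diff exercises exactly this: the copied connection, when first used from another thread, re-runs __init__ and reconnects.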
@@ -4,11 +4,12 @@ import logging
 import time
 
 try:
-    from queue import Empty
+    from queue import Empty, Queue
 except ImportError:
-    from Queue import Empty
+    from Queue import Empty, Queue
 from collections import defaultdict
-from multiprocessing import Queue, Process
+
+from threading import Thread, Event
 
 import six
 
@@ -26,20 +27,15 @@ STOP_ASYNC_PRODUCER = -1
 
 
 def _send_upstream(queue, client, codec, batch_time, batch_size,
-                   req_acks, ack_timeout):
+                   req_acks, ack_timeout, stop_event):
     """
     Listen on the queue for a specified number of messages or till
     a specified timeout and send them upstream to the brokers in one
     request
-
-    NOTE: Ideally, this should have been a method inside the Producer
-    class. However, multiprocessing module has issues in windows. The
-    functionality breaks unless this function is kept outside of a class
     """
-    stop = False
     client.reinit()
 
-    while not stop:
+    while not stop_event.is_set():
         timeout = batch_time
         count = batch_size
         send_at = time.time() + timeout
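The loop above gathers up to batch_size messages, or whatever arrives before batch_time elapses, shrinking the queue timeout as the send deadline approaches. A rough standalone sketch of that drain-until-count-or-deadline logic, assuming a plain queue.Queue and no actual Kafka request:

    import time
    from queue import Empty, Queue

    def drain_batch(q, batch_size, batch_time):
        # Collect up to batch_size items, or whatever arrives within batch_time.
        batch = []
        send_at = time.time() + batch_time
        timeout = batch_time
        while len(batch) < batch_size and timeout >= 0:
            try:
                batch.append(q.get(timeout=timeout))
            except Empty:
                break
            timeout = send_at - time.time()   # shrink the wait as the deadline nears
        return batch

    q = Queue()
    for i in range(3):
        q.put(i)
    print(drain_batch(q, batch_size=10, batch_time=0.1))   # -> [0, 1, 2]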
@@ -56,7 +52,7 @@ def _send_upstream(queue, client, codec, batch_time, batch_size,
 
         # Check if the controller has requested us to stop
         if topic_partition == STOP_ASYNC_PRODUCER:
-            stop = True
+            stop_event.set()
             break
 
         # Adjust the timeout to match the remaining period
@@ -141,18 +137,22 @@ class Producer(object):
             log.warning("Current implementation does not retry Failed messages")
             log.warning("Use at your own risk! (or help improve with a PR!)")
             self.queue = Queue()  # Messages are sent through this queue
-            self.proc = Process(target=_send_upstream,
-                                args=(self.queue,
-                                      self.client.copy(),
-                                      self.codec,
-                                      batch_send_every_t,
-                                      batch_send_every_n,
-                                      self.req_acks,
-                                      self.ack_timeout))
+            self.thread_stop_event = Event()
+            self.thread = Thread(target=_send_upstream,
+                                 args=(self.queue,
+                                       self.client.copy(),
+                                       self.codec,
+                                       batch_send_every_t,
+                                       batch_send_every_n,
+                                       self.req_acks,
+                                       self.ack_timeout,
+                                       self.thread_stop_event))
 
-            # Process will die if main thread exits
-            self.proc.daemon = True
-            self.proc.start()
+            # Thread will die if main thread exits
+            self.thread.daemon = True
+            self.thread.start()
+
+
 
     def send_messages(self, topic, partition, *msg):
         """
@@ -209,10 +209,10 @@ class Producer(object):
         """
        if self.async:
             self.queue.put((STOP_ASYNC_PRODUCER, None, None))
-            self.proc.join(timeout)
+            self.thread.join(timeout)
 
-            if self.proc.is_alive():
-                self.proc.terminate()
+            if self.thread.is_alive():
+                self.thread_stop_event.set()
         self.stopped = True
 
     def __del__(self):
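stop() now asks the worker to finish by enqueueing the STOP_ASYNC_PRODUCER sentinel, waits with join(timeout), and only then falls back to setting the event, since a thread has no equivalent of Process.terminate(). A sketch of that shutdown sequence, under the same hypothetical names as the first example:

    import queue
    import threading

    STOP = object()   # stands in for STOP_ASYNC_PRODUCER

    def worker(q, stop_event):
        while not stop_event.is_set():
            if q.get() is STOP:
                stop_event.set()

    def shutdown(q, thread, stop_event, timeout=None):
        q.put(STOP)             # ask the worker to drain and exit
        thread.join(timeout)    # wait up to `timeout` seconds
        if thread.is_alive():   # still busy: no terminate() for threads,
            stop_event.set()    # so flip the event and let the loop end itself

    q = queue.Queue()
    stop_event = threading.Event()
    t = threading.Thread(target=worker, args=(q, stop_event), daemon=True)
    t.start()
    shutdown(q, t, stop_event, timeout=5)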
@@ -1,5 +1,6 @@
 import socket
 import struct
+from threading import Thread
 
 import mock
 from . import unittest
@@ -162,3 +163,46 @@ class ConnTest(unittest.TestCase):
         self.conn.send(self.config['request_id'], self.config['payload'])
         self.assertEqual(self.MockCreateConn.call_count, 1)
         self.conn._sock.sendall.assert_called_with(self.config['payload'])
+
+
+class TestKafkaConnection(unittest.TestCase):
+
+    @mock.patch('socket.create_connection')
+    def test_copy(self, socket):
+        """KafkaConnection copies work as expected"""
+
+        conn = KafkaConnection('kafka', 9092)
+        self.assertEqual(socket.call_count, 1)
+
+        copy = conn.copy()
+        self.assertEqual(socket.call_count, 1)
+        self.assertEqual(copy.host, 'kafka')
+        self.assertEqual(copy.port, 9092)
+        self.assertEqual(copy._sock, None)
+
+        copy.reinit()
+        self.assertEqual(socket.call_count, 2)
+        self.assertNotEqual(copy._sock, None)
+
+    @mock.patch('socket.create_connection')
+    def test_copy_thread(self, socket):
+        """KafkaConnection copies work in other threads"""
+
+        err = []
+        copy = KafkaConnection('kafka', 9092).copy()
+
+        def thread_func(err, copy):
+            try:
+                self.assertEqual(copy.host, 'kafka')
+                self.assertEqual(copy.port, 9092)
+                self.assertNotEqual(copy._sock, None)
+            except Exception as e:
+                err.append(e)
+            else:
+                err.append(None)
+        thread = Thread(target=thread_func, args=(err, copy))
+        thread.start()
+        thread.join()
+
+        self.assertEqual(err, [None])
+        self.assertEqual(socket.call_count, 2)
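The new tests patch socket.create_connection, so the mock's call_count records how many real connection attempts would have been made: one for the original connection, none for the inert copy, and one more once the copy reconnects (via reinit(), or on first use from another thread), as the assertions above show. A small standalone reminder of how that counter works (unittest.mock here; the test module uses the external mock package):

    import socket
    from unittest import mock

    with mock.patch('socket.create_connection') as create_conn:
        socket.create_connection(('kafka', 9092))   # no real network traffic
        socket.create_connection(('kafka', 9092))

    print(create_conn.call_count)                     # 2
    create_conn.assert_called_with(('kafka', 9092))   # arguments of the last call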