Merge branch 'vshlapakov-feature-async-threading'

PR 330: Threading for async batching

Conflicts:
	kafka/producer/base.py
This commit is contained in:
Dana Powers
2015-03-08 16:03:06 -07:00
3 changed files with 72 additions and 25 deletions

View File

@@ -62,6 +62,9 @@ class KafkaConnection(local):
         self.reinit()

+    def __getnewargs__(self):
+        return (self.host, self.port, self.timeout)
+
     def __repr__(self):
         return "<KafkaConnection host=%s port=%d>" % (self.host, self.port)

View File

@@ -4,11 +4,12 @@ import logging
 import time

 try:
-    from queue import Empty
+    from queue import Empty, Queue
 except ImportError:
-    from Queue import Empty
+    from Queue import Empty, Queue
 from collections import defaultdict
-from multiprocessing import Queue, Process
+from threading import Thread, Event

 import six
@@ -26,20 +27,15 @@ STOP_ASYNC_PRODUCER = -1

 def _send_upstream(queue, client, codec, batch_time, batch_size,
-                   req_acks, ack_timeout):
+                   req_acks, ack_timeout, stop_event):
     """
     Listen on the queue for a specified number of messages or till
     a specified timeout and send them upstream to the brokers in one
     request
-
-    NOTE: Ideally, this should have been a method inside the Producer
-    class. However, multiprocessing module has issues in windows. The
-    functionality breaks unless this function is kept outside of a class
     """
     stop = False
-    client.reinit()

-    while not stop:
+    while not stop_event.is_set():
         timeout = batch_time
         count = batch_size
         send_at = time.time() + timeout
@@ -56,7 +52,7 @@ def _send_upstream(queue, client, codec, batch_time, batch_size,

             # Check if the controller has requested us to stop
             if topic_partition == STOP_ASYNC_PRODUCER:
-                stop = True
+                stop_event.set()
                 break

             # Adjust the timeout to match the remaining period
@@ -141,18 +137,22 @@ class Producer(object):
             log.warning("Current implementation does not retry Failed messages")
             log.warning("Use at your own risk! (or help improve with a PR!)")
             self.queue = Queue()  # Messages are sent through this queue
-            self.proc = Process(target=_send_upstream,
+            self.thread_stop_event = Event()
+            self.thread = Thread(target=_send_upstream,
                                  args=(self.queue,
                                        self.client.copy(),
                                        self.codec,
                                        batch_send_every_t,
                                        batch_send_every_n,
                                        self.req_acks,
-                                       self.ack_timeout))
+                                       self.ack_timeout,
+                                       self.thread_stop_event))

-            # Process will die if main thread exits
-            self.proc.daemon = True
-            self.proc.start()
+            # Thread will die if main thread exits
+            self.thread.daemon = True
+            self.thread.start()

     def send_messages(self, topic, partition, *msg):
         """
@@ -209,10 +209,10 @@ class Producer(object):
         """
         if self.async:
             self.queue.put((STOP_ASYNC_PRODUCER, None, None))
-            self.proc.join(timeout)
+            self.thread.join(timeout)

-            if self.proc.is_alive():
-                self.proc.terminate()
+            if self.thread.is_alive():
+                self.thread_stop_event.set()

         self.stopped = True

     def __del__(self):

View File

@@ -1,5 +1,6 @@
 import socket
 import struct
+from threading import Thread

 import mock
 from . import unittest
@@ -162,3 +163,46 @@ class ConnTest(unittest.TestCase):
         self.conn.send(self.config['request_id'], self.config['payload'])
         self.assertEqual(self.MockCreateConn.call_count, 1)
         self.conn._sock.sendall.assert_called_with(self.config['payload'])
+
+
+class TestKafkaConnection(unittest.TestCase):
+    @mock.patch('socket.create_connection')
+    def test_copy(self, socket):
+        """KafkaConnection copies work as expected"""
+        conn = KafkaConnection('kafka', 9092)
+        self.assertEqual(socket.call_count, 1)
+        copy = conn.copy()
+        self.assertEqual(socket.call_count, 1)
+        self.assertEqual(copy.host, 'kafka')
+        self.assertEqual(copy.port, 9092)
+        self.assertEqual(copy._sock, None)
+
+        copy.reinit()
+        self.assertEqual(socket.call_count, 2)
+        self.assertNotEqual(copy._sock, None)
+
+    @mock.patch('socket.create_connection')
+    def test_copy_thread(self, socket):
+        """KafkaConnection copies work in other threads"""
+        err = []
+        copy = KafkaConnection('kafka', 9092).copy()
+
+        def thread_func(err, copy):
+            try:
+                self.assertEqual(copy.host, 'kafka')
+                self.assertEqual(copy.port, 9092)
+                self.assertNotEqual(copy._sock, None)
+            except Exception as e:
+                err.append(e)
+            else:
+                err.append(None)
+        thread = Thread(target=thread_func, args=(err, copy))
+        thread.start()
+        thread.join()
+        self.assertEqual(err, [None])
+
+        self.assertEqual(socket.call_count, 2)