Returned original tests, rm dirty flag, name fixes

This commit is contained in:
Viktor Shlapakov
2015-02-25 10:45:47 +03:00
parent bc0d5c1e27
commit 1cce287157
3 changed files with 60 additions and 24 deletions

View File

@@ -170,7 +170,6 @@ class KafkaConnection(local):
c.port = copy.copy(self.port) c.port = copy.copy(self.port)
c.timeout = copy.copy(self.timeout) c.timeout = copy.copy(self.timeout)
c._sock = None c._sock = None
c._dirty = True
return c return c
def close(self): def close(self):

View File

@@ -3,11 +3,10 @@ from __future__ import absolute_import
import logging import logging
import time import time
from Queue import Queue
try: try:
from queue import Empty from queue import Empty, Queue
except ImportError: except ImportError:
from Queue import Empty from Queue import Empty, Queue
from collections import defaultdict from collections import defaultdict
from threading import Thread from threading import Thread
@@ -33,13 +32,8 @@ def _send_upstream(queue, client, codec, batch_time, batch_size,
Listen on the queue for a specified number of messages or till Listen on the queue for a specified number of messages or till
a specified timeout and send them upstream to the brokers in one a specified timeout and send them upstream to the brokers in one
request request
NOTE: Ideally, this should have been a method inside the Producer
class. However, multiprocessing module has issues in windows. The
functionality breaks unless this function is kept outside of a class
""" """
stop = False stop = False
client.reinit()
while not stop: while not stop:
timeout = batch_time timeout = batch_time
@@ -142,7 +136,7 @@ class Producer(object):
log.warning("Current implementation does not retry Failed messages") log.warning("Current implementation does not retry Failed messages")
log.warning("Use at your own risk! (or help improve with a PR!)") log.warning("Use at your own risk! (or help improve with a PR!)")
self.queue = Queue() # Messages are sent through this queue self.queue = Queue() # Messages are sent through this queue
self.proc = Thread(target=_send_upstream, self.thread = Thread(target=_send_upstream,
args=(self.queue, args=(self.queue,
self.client.copy(), self.client.copy(),
self.codec, self.codec,
@@ -151,9 +145,11 @@ class Producer(object):
self.req_acks, self.req_acks,
self.ack_timeout)) self.ack_timeout))
# Process will die if main thread exits # Thread will die if main thread exits
self.proc.daemon = True self.thread.daemon = True
self.proc.start() self.thread.start()
def send_messages(self, topic, partition, *msg): def send_messages(self, topic, partition, *msg):
""" """
@@ -210,7 +206,4 @@ class Producer(object):
""" """
if self.async: if self.async:
self.queue.put((STOP_ASYNC_PRODUCER, None, None)) self.queue.put((STOP_ASYNC_PRODUCER, None, None))
self.proc.join(timeout) self.thread.join(timeout)
if self.proc.is_alive():
raise SystemError("Can't join Kafka async thread")

View File

@@ -1,5 +1,6 @@
import socket import socket
import struct import struct
from threading import Thread
import mock import mock
from . import unittest from . import unittest
@@ -162,3 +163,46 @@ class ConnTest(unittest.TestCase):
self.conn.send(self.config['request_id'], self.config['payload']) self.conn.send(self.config['request_id'], self.config['payload'])
self.assertEqual(self.MockCreateConn.call_count, 1) self.assertEqual(self.MockCreateConn.call_count, 1)
self.conn._sock.sendall.assert_called_with(self.config['payload']) self.conn._sock.sendall.assert_called_with(self.config['payload'])
class TestKafkaConnection(unittest.TestCase):
@mock.patch('socket.create_connection')
def test_copy(self, socket):
"""KafkaConnection copies work as expected"""
conn = KafkaConnection('kafka', 9092)
self.assertEqual(socket.call_count, 1)
copy = conn.copy()
self.assertEqual(socket.call_count, 1)
self.assertEqual(copy.host, 'kafka')
self.assertEqual(copy.port, 9092)
self.assertEqual(copy._sock, None)
copy.reinit()
self.assertEqual(socket.call_count, 2)
self.assertNotEqual(copy._sock, None)
@mock.patch('socket.create_connection')
def test_copy_thread(self, socket):
"""KafkaConnection copies work in other threads"""
err = []
copy = KafkaConnection('kafka', 9092).copy()
def thread_func(err, copy):
try:
self.assertEqual(copy.host, 'kafka')
self.assertEqual(copy.port, 9092)
self.assertNotEqual(copy._sock, None)
except Exception as e:
err.append(e)
else:
err.append(None)
thread = Thread(target=thread_func, args=(err, copy))
thread.start()
thread.join()
self.assertEqual(err, [None])
self.assertEqual(socket.call_count, 2)