Just encode to bytes if it isn't bytes.
This commit is contained in:
@@ -70,10 +70,12 @@ class SimpleProducer(Producer):
|
||||
|
||||
def send_messages(self, topic, *msg):
|
||||
if not isinstance(topic, six.binary_type):
|
||||
raise TypeError("topic must be type bytes")
|
||||
topic = topic.encode('utf-8')
|
||||
|
||||
partition = self._next_partition(topic)
|
||||
return super(SimpleProducer, self).send_messages(topic, partition, *msg)
|
||||
return super(SimpleProducer, self).send_messages(
|
||||
topic, partition, *msg
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return '<SimpleProducer batch=%s>' % self.async
|
||||
|
@@ -7,6 +7,7 @@ from . import unittest
|
||||
|
||||
from kafka.producer.base import Producer
|
||||
|
||||
|
||||
class TestKafkaProducer(unittest.TestCase):
|
||||
def test_producer_message_types(self):
|
||||
|
||||
@@ -28,11 +29,14 @@ class TestKafkaProducer(unittest.TestCase):
|
||||
def test_topic_message_types(self):
|
||||
from kafka.producer.simple import SimpleProducer
|
||||
|
||||
producer = SimpleProducer(MagicMock())
|
||||
topic = "test-topic"
|
||||
partition = 0
|
||||
client = MagicMock()
|
||||
|
||||
def send_message():
|
||||
producer.send_messages(topic, partition, b'hi')
|
||||
def partitions(topic):
|
||||
return [0, 1]
|
||||
|
||||
self.assertRaises(TypeError, send_message)
|
||||
client.get_partition_ids_for_topic = partitions
|
||||
|
||||
producer = SimpleProducer(client, random_start=False)
|
||||
topic = b"test-topic"
|
||||
producer.send_messages(topic, b'hi')
|
||||
assert client.send_produce_request.called
|
||||
|
Reference in New Issue
Block a user