Merge pull request #602 from zackdever/KAFKA-2698
KAFKA-2698: add paused API
This commit is contained in:
@@ -528,6 +528,14 @@ class KafkaConsumer(six.Iterator):
|
||||
log.debug("Pausing partition %s", partition)
|
||||
self._subscription.pause(partition)
|
||||
|
||||
def paused(self):
|
||||
"""Get the partitions that were previously paused by a call to pause().
|
||||
|
||||
Returns:
|
||||
set: {partition (TopicPartition), ...}
|
||||
"""
|
||||
return self._subscription.paused_partitions()
|
||||
|
||||
def resume(self, *partitions):
|
||||
"""Resume fetching from the specified (paused) partitions.
|
||||
|
||||
|
||||
@@ -265,6 +265,11 @@ class SubscriptionState(object):
|
||||
"""Return set of TopicPartitions in current assignment."""
|
||||
return set(self.assignment.keys())
|
||||
|
||||
def paused_partitions(self):
|
||||
"""Return current set of paused TopicPartitions."""
|
||||
return set(partition for partition in self.assignment
|
||||
if self.is_paused(partition))
|
||||
|
||||
def fetchable_partitions(self):
|
||||
"""Return set of TopicPartitions that should be Fetched."""
|
||||
fetchable = set()
|
||||
|
||||
@@ -17,10 +17,13 @@ from test.conftest import version
|
||||
from test.testutil import random_string
|
||||
|
||||
|
||||
def get_connect_str(kafka_broker):
|
||||
return 'localhost:' + str(kafka_broker.port)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_client(kafka_broker):
|
||||
connect_str = 'localhost:' + str(kafka_broker.port)
|
||||
return SimpleClient(connect_str)
|
||||
return SimpleClient(get_connect_str(kafka_broker))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -37,8 +40,7 @@ def test_consumer(kafka_broker, version):
|
||||
if version >= (0, 8, 2) and version < (0, 9):
|
||||
topic(simple_client(kafka_broker))
|
||||
|
||||
connect_str = 'localhost:' + str(kafka_broker.port)
|
||||
consumer = KafkaConsumer(bootstrap_servers=connect_str)
|
||||
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
|
||||
consumer.poll(500)
|
||||
assert len(consumer._client._conns) > 0
|
||||
node_id = list(consumer._client._conns.keys())[0]
|
||||
@@ -49,7 +51,7 @@ def test_consumer(kafka_broker, version):
|
||||
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
|
||||
def test_group(kafka_broker, topic):
|
||||
num_partitions = 4
|
||||
connect_str = 'localhost:' + str(kafka_broker.port)
|
||||
connect_str = get_connect_str(kafka_broker)
|
||||
consumers = {}
|
||||
stop = {}
|
||||
threads = {}
|
||||
@@ -120,6 +122,24 @@ def test_group(kafka_broker, topic):
|
||||
threads[c].join()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
|
||||
def test_paused(kafka_broker, topic):
|
||||
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
|
||||
topics = [TopicPartition(topic, 1)]
|
||||
consumer.assign(topics)
|
||||
assert set(topics) == consumer.assignment()
|
||||
assert set() == consumer.paused()
|
||||
|
||||
consumer.pause(topics[0])
|
||||
assert set([topics[0]]) == consumer.paused()
|
||||
|
||||
consumer.resume(topics[0])
|
||||
assert set() == consumer.paused()
|
||||
|
||||
consumer.unsubscribe()
|
||||
assert set() == consumer.paused()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conn(mocker):
|
||||
conn = mocker.patch('kafka.client_async.BrokerConnection')
|
||||
|
||||
Reference in New Issue
Block a user