Use weakref when registering a producer.close atexit to fix normal gc (#728)

* Use weakref when registering a producer.close atexit to fix normal gc
* Test that del(producer) terminates async thread
This commit is contained in:
Dana Powers
2016-06-18 14:51:23 -07:00
committed by GitHub
parent 5b9c55817b
commit 6271c02c6e
3 changed files with 61 additions and 3 deletions

View File

@@ -5,12 +5,13 @@ import copy
import logging import logging
import threading import threading
import time import time
import weakref
from .. import errors as Errors
from ..client_async import KafkaClient from ..client_async import KafkaClient
from ..structs import TopicPartition
from ..partitioner.default import DefaultPartitioner from ..partitioner.default import DefaultPartitioner
from ..protocol.message import Message, MessageSet from ..protocol.message import Message, MessageSet
from .. import errors as Errors from ..structs import TopicPartition
from .future import FutureRecordMetadata, FutureProduceResult from .future import FutureRecordMetadata, FutureProduceResult
from .record_accumulator import AtomicInteger, RecordAccumulator from .record_accumulator import AtomicInteger, RecordAccumulator
from .sender import Sender from .sender import Sender
@@ -293,14 +294,47 @@ class KafkaProducer(object):
self._sender.daemon = True self._sender.daemon = True
self._sender.start() self._sender.start()
self._closed = False self._closed = False
atexit.register(self.close, timeout=0)
self._cleanup = self._cleanup_factory()
atexit.register(self._cleanup)
log.debug("Kafka producer started") log.debug("Kafka producer started")
def _cleanup_factory(self):
"""Build a cleanup clojure that doesn't increase our ref count"""
_self = weakref.proxy(self)
def wrapper():
try:
_self.close()
except (ReferenceError, AttributeError):
pass
return wrapper
def _unregister_cleanup(self):
if getattr(self, '_cleanup'):
if hasattr(atexit, 'unregister'):
atexit.unregister(self._cleanup) # pylint: disable=no-member
# py2 requires removing from private attribute...
else:
# ValueError on list.remove() if the exithandler no longer exists
# but that is fine here
try:
atexit._exithandlers.remove( # pylint: disable=no-member
(self._cleanup, (), {}))
except ValueError:
pass
self._cleanup = None
def __del__(self): def __del__(self):
self.close(timeout=0) self.close(timeout=0)
def close(self, timeout=None): def close(self, timeout=None):
"""Close this producer.""" """Close this producer."""
# drop our atexit handler now to avoid leaks
self._unregister_cleanup()
if not hasattr(self, '_closed') or self._closed: if not hasattr(self, '_closed') or self._closed:
log.info('Kafka producer closed') log.info('Kafka producer closed')
return return

View File

@@ -1,3 +1,4 @@
import atexit
import binascii import binascii
import collections import collections
import struct import struct
@@ -188,3 +189,12 @@ class WeakMethod(object):
if not isinstance(other, WeakMethod): if not isinstance(other, WeakMethod):
return False return False
return self._target_id == other._target_id and self._method_id == other._method_id return self._target_id == other._target_id and self._method_id == other._method_id
def try_method_on_system_exit(obj, method, *args, **kwargs):
def wrapper(_obj, _meth, *args, **kwargs):
try:
getattr(_obj, _meth)(*args, **kwargs)
except (ReferenceError, AttributeError):
pass
atexit.register(wrapper, weakref.proxy(obj), method, *args, **kwargs)

View File

@@ -1,4 +1,7 @@
import gc
import platform
import sys import sys
import threading
import pytest import pytest
@@ -64,3 +67,14 @@ def test_end_to_end(kafka_broker, compression):
break break
assert msgs == set(['msg %d' % i for i in range(messages)]) assert msgs == set(['msg %d' % i for i in range(messages)])
@pytest.mark.skipif(platform.python_implementation() != 'CPython',
reason='Test relies on CPython-specific gc policies')
def test_kafka_producer_gc_cleanup():
threads = threading.active_count()
producer = KafkaProducer(api_version='0.9') # set api_version explicitly to avoid auto-detection
assert threading.active_count() == threads + 1
del(producer)
gc.collect()
assert threading.active_count() == threads