
Under certain error conditions, this could leak file descriptors (FDs) for the socket handles. (In my testing, the sockets eventually seemed to be garbage-collected and closed, but that behavior is probably not dependable.) Fixes #80
444 lines
14 KiB
Python
444 lines
14 KiB
Python
from collections import defaultdict, deque
from functools import partial, wraps
import logging
import os
import socket
import sys
from threading import Event, Lock, Thread
import time
import traceback
import Queue

from cassandra import OperationTimedOut
from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown,
                                  ConnectionBusy, NONBLOCKING,
                                  MAX_STREAM_PER_CONNECTION)
from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack
import cassandra.io.libevwrapper as libev

try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO  # ignore flake8 warning: # NOQA

try:
    import ssl
except ImportError:
    ssl = None  # NOQA
|
|
|
|
log = logging.getLogger(__name__)

# One libev loop shared by every LibevConnection in this process. It is run on
# a daemon thread that _start_loop() creates on demand.
_loop = libev.Loop()
# Async watcher that other threads call send() on after changing connection
# state (see LibevConnection.push/close/__init__).
_loop_notifier = libev.Async(_loop)
_loop_notifier.start()

# prevent _loop_notifier from keeping the loop from returning
_loop.unref()

# Whether the event-loop thread is currently running. Read and written under
# _loop_lock by _start_loop() and _run_loop().
_loop_started = None
_loop_lock = Lock()
|
|
|
|
|
|
def _run_loop():
    """
    Body of the event-loop thread.

    Repeatedly runs the shared libev loop. When ``_loop.start()`` returns a
    truthy value the loop is restarted. Otherwise no watchers are active, so
    ``_loop_started`` is reset (under ``_loop_lock``) and the thread exits.
    """
    global _loop_started

    keep_running = True
    while keep_running:
        watchers_remain = _loop.start()
        # decide under the lock so _start_loop() sees a consistent flag
        with _loop_lock:
            if not watchers_remain:
                # all Connections have been closed, no active watchers
                log.debug("All Connections currently closed, event loop ended")
                _loop_started = False
                keep_running = False
            else:
                log.debug("Restarting event loop")
|
|
|
|
|
|
def _start_loop():
    """
    Start the event-loop thread if it is not already running.

    The check-and-set of ``_loop_started`` happens under ``_loop_lock``, so at
    most one caller spawns the thread. The thread itself is started outside
    the lock.

    :returns: True if this call started a new loop thread, False otherwise.
    """
    global _loop_started

    with _loop_lock:
        needs_thread = not _loop_started
        if needs_thread:
            log.debug("Starting libev event loop")
            _loop_started = True

    if needs_thread:
        loop_thread = Thread(target=_run_loop, name="event_loop")
        loop_thread.daemon = True
        loop_thread.start()

    return needs_thread
|
|
|
|
|
|
def defunct_on_error(f):
    """
    Decorate a method so that any ``Exception`` it raises defuncts ``self``.

    The exception is passed to ``self.defunct(exc)`` and is not re-raised; the
    wrapped call then returns None. On success the method's result is
    returned unchanged.
    """

    @wraps(f)
    def guarded(self, *args, **kwargs):
        try:
            result = f(self, *args, **kwargs)
        except Exception as exc:
            self.defunct(exc)
        else:
            return result

    return guarded
|
|
|
|
|
|
class LibevConnection(Connection):
    """
    An implementation of :class:`.Connection` that uses libev for its event loop.
    """

    # class-level set of all connections; only replaced with a new copy
    # while holding _conn_set_lock, never modified in place
    _live_conns = set()
    # newly created connections that need their write/read watcher started
    _new_conns = set()
    # recently closed connections that need their write/read watcher stopped
    _closed_conns = set()
    _conn_set_lock = Lock()

    # per-instance flag (class default) tracking whether _write_watcher is started
    _write_watcher_is_active = False

    # full frame length (header + body) we are waiting for; 0 when unknown
    _total_reqd_bytes = 0
    _read_watcher = None
    _write_watcher = None
    _socket = None

    @classmethod
    def factory(cls, *args, **kwargs):
        """
        Create a connection and block until it is ready or has failed.

        ``timeout`` (default 5.0 seconds) is popped from ``kwargs``. All other
        arguments are passed to the constructor.

        :raises: the connection's ``last_error`` if it was defuncted while
            connecting, or :exc:`OperationTimedOut` (after closing the
            connection) if ``connected_event`` was not set within ``timeout``.
        """
        timeout = kwargs.pop('timeout', 5.0)
        conn = cls(*args, **kwargs)
        conn.connected_event.wait(timeout)
        if conn.last_error:
            raise conn.last_error
        elif not conn.connected_event.is_set():
            conn.close()
            raise OperationTimedOut("Timed out creating new connection")
        else:
            return conn

    @classmethod
    def _connection_created(cls, conn):
        """Register ``conn`` as live and queue its read watcher to be started."""
        with cls._conn_set_lock:
            new_live_conns = cls._live_conns.copy()
            new_live_conns.add(conn)
            cls._live_conns = new_live_conns

            new_new_conns = cls._new_conns.copy()
            new_new_conns.add(conn)
            cls._new_conns = new_new_conns

    @classmethod
    def _connection_destroyed(cls, conn):
        """Unregister ``conn`` and queue its watchers to be stopped."""
        with cls._conn_set_lock:
            new_live_conns = cls._live_conns.copy()
            new_live_conns.discard(conn)
            cls._live_conns = new_live_conns

            # If the loop has not yet started this connection's read watcher,
            # make sure it never does: the socket is about to be closed.
            if conn in cls._new_conns:
                new_new_conns = cls._new_conns.copy()
                new_new_conns.discard(conn)
                cls._new_conns = new_new_conns

            new_closed_conns = cls._closed_conns.copy()
            new_closed_conns.add(conn)
            cls._closed_conns = new_closed_conns

    @classmethod
    def loop_will_run(cls, prepare):
        """
        Sync libev watchers with connection state before the loop polls.

        This is registered as the loop's libev Prepare callback (see the
        module-level ``_preparer``). It:

        * starts or stops each live connection's write watcher depending on
          whether that connection has queued outgoing data;
        * starts read watchers for newly created connections;
        * stops both watchers for recently closed connections.

        If anything changed, ``_loop_notifier`` is signalled.
        """
        changed = False
        # _live_conns is copy-on-write, so iterating it without the lock is safe
        for conn in cls._live_conns:
            if not conn.deque and conn._write_watcher_is_active:
                if conn._write_watcher:
                    conn._write_watcher.stop()
                conn._write_watcher_is_active = False
                changed = True
            elif conn.deque and not conn._write_watcher_is_active:
                conn._write_watcher.start()
                conn._write_watcher_is_active = True
                changed = True

        if cls._new_conns:
            with cls._conn_set_lock:
                to_start = cls._new_conns
                cls._new_conns = set()

            for conn in to_start:
                conn._read_watcher.start()

            changed = True

        if cls._closed_conns:
            with cls._conn_set_lock:
                to_stop = cls._closed_conns
                cls._closed_conns = set()

            for conn in to_stop:
                if conn._write_watcher:
                    conn._write_watcher.stop()
                if conn._read_watcher:
                    conn._read_watcher.stop()

            changed = True

        if changed:
            _loop_notifier.send()

    def __init__(self, *args, **kwargs):
        """
        Open the socket, connect, and register with the shared event loop.

        Any failure while setting up the socket closes it before the exception
        propagates, so a failed constructor does not leak a file descriptor.
        """
        Connection.__init__(self, *args, **kwargs)

        self.connected_event = Event()
        self._iobuf = StringIO()

        self._callbacks = {}
        self._push_watchers = defaultdict(set)
        self.deque = deque()
        self._deque_lock = Lock()

        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if self.ssl_options:
                if not ssl:
                    raise Exception("This version of Python was not compiled with SSL support")
                self._socket = ssl.wrap_socket(self._socket, **self.ssl_options)
            self._socket.settimeout(1.0)  # TODO potentially make this value configurable
            self._socket.connect((self.host, self.port))
            self._socket.setblocking(0)

            if self.sockopts:
                for args in self.sockopts:
                    self._socket.setsockopt(*args)

            with _loop_lock:
                self._read_watcher = libev.IO(self._socket._sock, libev.EV_READ, _loop, self.handle_read)
                self._write_watcher = libev.IO(self._socket._sock, libev.EV_WRITE, _loop, self.handle_write)

            self._send_options_message()
        except BaseException:
            # Nothing has been registered with the loop yet (the watchers were
            # never started), so closing the socket here is all the cleanup needed.
            self._socket.close()
            raise

        self.__class__._connection_created(self)

        # start the global event loop if needed
        _start_loop()
        _loop_notifier.send()

    def close(self):
        """
        Close the connection. Safe to call more than once; only the first call acts.

        Queues the watchers to be stopped by the loop thread, closes the
        socket, and (unless the connection was defuncted, which errors them
        itself) fails all pending callbacks with :exc:`ConnectionShutdown`.
        """
        with self.lock:
            if self.is_closed:
                return
            self.is_closed = True

        log.debug("Closing connection (%s) to %s", id(self), self.host)
        self.__class__._connection_destroyed(self)
        _loop_notifier.send()
        # NOTE(review): the socket is closed before the loop thread has
        # necessarily stopped its watchers; libev recommends stopping first.
        self._socket.close()

        # don't leave in-progress operations hanging
        if not self.is_defunct:
            self._error_all_callbacks(
                ConnectionShutdown("Connection to %s was closed" % self.host))

    def defunct(self, exc):
        """
        Mark the connection as failed because of ``exc`` and close it.

        Records ``exc`` as ``last_error``, closes the connection, errors all
        pending callbacks and sets ``connected_event`` so that ``factory``
        wakes up.

        :returns: ``exc``, or None if the connection was already defunct or
            closed.
        """
        with self.lock:
            if self.is_defunct or self.is_closed:
                return
            self.is_defunct = True

        # Only include a traceback when an exception is actually being handled;
        # otherwise format_exc() just yields a meaningless "None" line.
        if sys.exc_info()[0] is not None:
            log.debug("Defuncting connection (%s) to %s: %s\n%s",
                      id(self), self.host, exc, traceback.format_exc())
        else:
            log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)

        self.last_error = exc
        self.close()
        self._error_all_callbacks(exc)
        self.connected_event.set()
        return exc

    def _error_all_callbacks(self, exc):
        """
        Fail every pending request callback with a :exc:`ConnectionShutdown`.

        The message of that exception is ``str(exc)``. Exceptions raised by
        callbacks are logged and ignored.
        """
        with self.lock:
            callbacks = self._callbacks
            self._callbacks = {}
        new_exc = ConnectionShutdown(str(exc))
        for cb in callbacks.values():
            try:
                cb(new_exc)
            except Exception:
                log.warn("Ignoring unhandled exception while erroring callbacks for a "
                         "failed connection (%s) to host %s:",
                         id(self), self.host, exc_info=True)

    def handle_write(self, watcher, revents, errno=None):
        """
        libev EV_WRITE callback: drain queued chunks from ``self.deque`` to the socket.

        A partially sent chunk has its unsent tail pushed back to the front of
        the deque. A non-blocking error re-queues the chunk and returns until
        the next writable event. Any other socket error defuncts the connection.
        """
        if revents & libev.EV_ERROR:
            if errno:
                exc = IOError(errno, os.strerror(errno))
            else:
                exc = Exception("libev reported an error")

            self.defunct(exc)
            return

        while True:
            try:
                with self._deque_lock:
                    next_msg = self.deque.popleft()
            except IndexError:
                return

            try:
                sent = self._socket.send(next_msg)
            except socket.error as err:
                if err.args[0] in NONBLOCKING:
                    # not writable right now: put the chunk back and wait for
                    # the next EV_WRITE rather than spinning on send()
                    with self._deque_lock:
                        self.deque.appendleft(next_msg)
                else:
                    self.defunct(err)
                return
            else:
                if sent < len(next_msg):
                    with self._deque_lock:
                        self.deque.appendleft(next_msg[sent:])

    def handle_read(self, watcher, revents, errno=None):
        """
        libev EV_READ callback: read available bytes and dispatch complete frames.

        Bytes accumulate in ``self._iobuf``. A frame is an 8-byte header whose
        bytes 4-8 hold the body length (decoded with ``int32_unpack``),
        followed by that many body bytes. Each complete frame is passed to
        ``process_msg(msg, body_len)``. If nothing at all is buffered after
        reading, the server is assumed to have closed the connection.
        """
        if revents & libev.EV_ERROR:
            if errno:
                exc = IOError(errno, os.strerror(errno))
            else:
                exc = Exception("libev reported an error")

            self.defunct(exc)
            return
        try:
            while True:
                buf = self._socket.recv(self.in_buffer_size)
                self._iobuf.write(buf)
                if len(buf) < self.in_buffer_size:
                    break
        except socket.error as err:
            if err.args[0] not in NONBLOCKING:
                self.defunct(err)
                return

        if self._iobuf.tell():
            while True:
                pos = self._iobuf.tell()
                if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
                    # we don't have a complete header yet or we
                    # already saw a header, but we don't have a
                    # complete message yet
                    break
                else:
                    # have enough for header, read body len from header
                    self._iobuf.seek(4)
                    body_len = int32_unpack(self._iobuf.read(4))

                    # seek to end to get length of current buffer
                    self._iobuf.seek(0, os.SEEK_END)
                    pos = self._iobuf.tell()

                    if pos >= body_len + 8:
                        # read message header and body
                        self._iobuf.seek(0)
                        msg = self._iobuf.read(8 + body_len)

                        # leave leftover in current buffer
                        leftover = self._iobuf.read()
                        self._iobuf = StringIO()
                        self._iobuf.write(leftover)

                        self._total_reqd_bytes = 0
                        self.process_msg(msg, body_len)
                    else:
                        self._total_reqd_bytes = body_len + 8
                        break
        else:
            log.debug("Connection %s closed by server", self)
            self.close()

    def handle_pushed(self, response):
        """Invoke every watcher registered for ``response.event_type``, ignoring their errors."""
        log.debug("Message pushed from server: %r", response)
        for cb in self._push_watchers.get(response.event_type, []):
            try:
                cb(response.event_args)
            except Exception:
                log.exception("Pushed event handler errored, ignoring:")

    def push(self, data):
        """
        Queue ``data`` for sending and wake the event loop.

        Data is split into chunks of at most ``out_buffer_size`` bytes.
        """
        sabs = self.out_buffer_size
        if len(data) > sabs:
            chunks = [data[i:i + sabs] for i in xrange(0, len(data), sabs)]
        else:
            chunks = [data]

        with self._deque_lock:
            self.deque.extend(chunks)

        _loop_notifier.send()

    def send_msg(self, msg, cb, wait_for_id=False):
        """
        Serialize ``msg`` under a free request id, register ``cb``, and queue it.

        :param wait_for_id: block for a free id instead of raising
            :exc:`ConnectionBusy` when none is available.
        :returns: the request id used.
        :raises ConnectionShutdown: if the connection is defunct or closed.
        """
        if self.is_defunct:
            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionShutdown("Connection to %s is closed" % self.host)

        if not wait_for_id:
            try:
                request_id = self._id_queue.get_nowait()
            except Queue.Empty:
                raise ConnectionBusy(
                    "Connection to %s is at the max number of requests" % self.host)
        else:
            request_id = self._id_queue.get()

        self._callbacks[request_id] = cb
        self.push(msg.to_string(request_id, compression=self.compressor))
        return request_id

    def wait_for_response(self, msg, timeout=None):
        """Send a single message and block for its response."""
        return self.wait_for_responses(msg, timeout=timeout)[0]

    def wait_for_responses(self, *msgs, **kwargs):
        """
        Send ``msgs`` and block until all responses arrive.

        Messages are sent only as stream capacity
        (``MAX_STREAM_PER_CONNECTION - in_flight``) allows, polling every
        10 ms. The optional ``timeout`` kwarg bounds that polling and is then
        passed to ``waiter.deliver``.

        :raises OperationTimedOut: if the timeout is exhausted.
        Any other error from delivery defuncts the connection and is re-raised.
        """
        timeout = kwargs.get('timeout')
        waiter = ResponseWaiter(self, len(msgs))

        # busy wait for sufficient space on the connection
        messages_sent = 0
        while True:
            needed = len(msgs) - messages_sent
            with self.lock:
                available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
                self.in_flight += available

            for i in range(messages_sent, messages_sent + available):
                self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
            messages_sent += available

            if messages_sent == len(msgs):
                break
            else:
                if timeout is not None:
                    timeout -= 0.01
                    if timeout <= 0.0:
                        raise OperationTimedOut()
                time.sleep(0.01)

        try:
            return waiter.deliver(timeout)
        except OperationTimedOut:
            raise
        except Exception as exc:
            self.defunct(exc)
            raise

    def register_watcher(self, event_type, callback):
        """Add ``callback`` for ``event_type`` and register the event with the server."""
        self._push_watchers[event_type].add(callback)
        self.wait_for_response(RegisterMessage(event_list=[event_type]))

    def register_watchers(self, type_callback_dict):
        """Add several event callbacks and register all their event types in one request."""
        for event_type, callback in type_callback_dict.items():
            self._push_watchers[event_type].add(callback)
        self.wait_for_response(RegisterMessage(event_list=type_callback_dict.keys()))
|
|
|
|
|
# Register LibevConnection.loop_will_run as a libev "prepare" watcher on the
# shared loop. Per libev's prepare-watcher semantics, it runs before the loop
# polls, so watcher start/stop requests queued by other threads are applied there.
_preparer = libev.Prepare(_loop, LibevConnection.loop_will_run)
# prevent _preparer from keeping the loop from returning
_loop.unref()
_preparer.start()
|