444 lines
14 KiB
Python
444 lines
14 KiB
Python
from collections import defaultdict, deque
from functools import partial, wraps
import logging
import os
import socket
import sys
from threading import Event, Lock, Thread
import time
import traceback
import Queue
|
|
|
|
from cassandra import OperationTimedOut
|
|
from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown,
|
|
ConnectionBusy, NONBLOCKING,
|
|
MAX_STREAM_PER_CONNECTION)
|
|
from cassandra.decoder import RegisterMessage
|
|
from cassandra.marshal import int32_unpack
|
|
import cassandra.io.libevwrapper as libev
|
|
|
|
try:
|
|
from cStringIO import StringIO
|
|
except ImportError:
|
|
from StringIO import StringIO # ignore flake8 warning: # NOQA
|
|
|
|
try:
|
|
import ssl
|
|
except ImportError:
|
|
ssl = None # NOQA
|
|
|
|
log = logging.getLogger(__name__)

# Single libev loop shared by every LibevConnection created in this module;
# all read/write watchers are registered against it.
_loop = libev.Loop()
# Async watcher that any thread signals with _loop_notifier.send() (e.g. after
# queueing outgoing data or changing the connection sets) to wake the loop.
_loop_notifier = libev.Async(_loop)
_loop_notifier.start()

# prevent _loop_notifier from keeping the loop from returning
_loop.unref()

# Truthy while an "event_loop" thread owns _loop. It is read and written only
# while holding _loop_lock (see _start_loop/_run_loop). The same lock also
# guards watcher creation in LibevConnection.__init__.
_loop_started = None
_loop_lock = Lock()
|
|
|
|
|
|
def _run_loop():
    """
    Body of the daemon "event_loop" thread.

    Runs ``_loop.start()`` repeatedly. After each return, the result is
    inspected while holding ``_loop_lock``:

    - A truthy result means the loop should keep going, so it is restarted.
    - A falsy result clears ``_loop_started`` and ends the thread. The next
      ``_start_loop()`` call then spawns a fresh one.
    """
    global _loop_started

    while True:
        end_condition = _loop.start()
        with _loop_lock:
            if not end_condition:
                # all Connections have been closed, no active watchers
                log.debug("All Connections currently closed, event loop ended")
                _loop_started = False
                break
            # there are still active watchers, no deadlock
            log.debug("Restarting event loop")
|
|
|
|
|
|
def _start_loop():
    """
    Make sure the shared libev loop is being run by an "event_loop" thread.

    The check-and-set of ``_loop_started`` happens under ``_loop_lock``, so
    concurrent callers spawn at most one thread. The thread itself is
    started outside the lock.

    :returns: True if this call spawned the thread, False if one was
        already marked as running.
    """
    global _loop_started

    with _loop_lock:
        spawn = not _loop_started
        if spawn:
            log.debug("Starting libev event loop")
            _loop_started = True

    if spawn:
        loop_thread = Thread(target=_run_loop, name="event_loop")
        loop_thread.daemon = True
        loop_thread.start()

    return spawn
|
|
|
|
|
|
def defunct_on_error(f):
    """
    Decorate a method so that any ``Exception`` it raises defuncts ``self``.

    The exception is passed to ``self.defunct(exc)`` and swallowed, and the
    wrapper returns None. On success the wrapped method's return value is
    passed through unchanged.
    """
    @wraps(f)
    def wrapper(self, *args, **kwargs):
        try:
            result = f(self, *args, **kwargs)
        except Exception as exc:
            self.defunct(exc)
        else:
            return result

    return wrapper
|
|
|
|
|
|
class LibevConnection(Connection):
    """
    An implementation of :class:`.Connection` that uses libev for its event loop.

    Every instance registers its read/write watchers on the shared
    module-level ``_loop``. Watchers are only started or stopped from
    :meth:`loop_will_run`, the loop's prepare callback. Other threads record
    what changed and then wake the loop with ``_loop_notifier.send()``. They
    record changes in two places:

    - the class-level connection sets, for new or closed connections;
    - a connection's outgoing ``deque``, for pending writes.
    """

    # class-level set of all connections; only replaced with a new copy
    # while holding _conn_set_lock, never modified in place
    _live_conns = set()
    # newly created connections that need their write/read watcher started
    _new_conns = set()
    # recently closed connections that need their write/read watcher stopped
    _closed_conns = set()
    _conn_set_lock = Lock()

    # whether this connection's write watcher is currently started; only
    # read and written by loop_will_run
    _write_watcher_is_active = False

    # total size (header + body) of the frame currently being accumulated in
    # _iobuf, or 0 when no partial frame has been measured yet
    _total_reqd_bytes = 0
    _read_watcher = None
    _write_watcher = None
    _socket = None

    @classmethod
    def factory(cls, *args, **kwargs):
        """
        Create a connection and wait until it is ready to use.

        :param timeout: passed as a keyword; seconds to wait on
            ``connected_event`` (default 5.0). All other arguments go to the
            constructor.
        :returns: the connected instance.
        :raises: the connection's ``last_error`` if it failed while
            connecting, or :exc:`OperationTimedOut` if ``connected_event`` was
            not set in time (the connection is closed first).
        """
        timeout = kwargs.pop('timeout', 5.0)
        conn = cls(*args, **kwargs)
        conn.connected_event.wait(timeout)
        if conn.last_error:
            raise conn.last_error
        elif not conn.connected_event.is_set():
            conn.close()
            raise OperationTimedOut("Timed out creating new connection")
        else:
            return conn

    @classmethod
    def _connection_created(cls, conn):
        """Register ``conn`` as live and queue its watchers to be started."""
        with cls._conn_set_lock:
            # copy-on-write so loop_will_run can iterate the old set lock-free
            new_live_conns = cls._live_conns.copy()
            new_live_conns.add(conn)
            cls._live_conns = new_live_conns

            new_new_conns = cls._new_conns.copy()
            new_new_conns.add(conn)
            cls._new_conns = new_new_conns

    @classmethod
    def _connection_destroyed(cls, conn):
        """Drop ``conn`` from the live set and queue its watchers to be stopped."""
        with cls._conn_set_lock:
            new_live_conns = cls._live_conns.copy()
            new_live_conns.discard(conn)
            cls._live_conns = new_live_conns

            new_closed_conns = cls._closed_conns.copy()
            new_closed_conns.add(conn)
            cls._closed_conns = new_closed_conns

    @classmethod
    def loop_will_run(cls, prepare):
        """
        Prepare callback for the shared loop; syncs watcher state with the
        connection sets.

        It is registered with ``libev.Prepare`` at module import. Each call:

        - starts or stops each live connection's write watcher depending on
          whether its outgoing ``deque`` is non-empty;
        - starts read watchers for newly created connections;
        - stops both watchers for closed connections.

        If anything changed, it signals ``_loop_notifier``.

        :param prepare: the watcher passed by libev; unused here.
        """
        changed = False
        # _live_conns is copy-on-write, so iterating it without the lock is safe
        for conn in cls._live_conns:
            if not conn.deque and conn._write_watcher_is_active:
                if conn._write_watcher:
                    conn._write_watcher.stop()
                conn._write_watcher_is_active = False
                changed = True
            elif conn.deque and not conn._write_watcher_is_active:
                conn._write_watcher.start()
                conn._write_watcher_is_active = True
                changed = True

        if cls._new_conns:
            with cls._conn_set_lock:
                to_start = cls._new_conns
                cls._new_conns = set()

            for conn in to_start:
                conn._read_watcher.start()

            changed = True

        if cls._closed_conns:
            with cls._conn_set_lock:
                to_stop = cls._closed_conns
                cls._closed_conns = set()

            for conn in to_stop:
                if conn._write_watcher:
                    conn._write_watcher.stop()
                if conn._read_watcher:
                    conn._read_watcher.stop()

            changed = True

        if changed:
            _loop_notifier.send()

    def __init__(self, *args, **kwargs):
        """
        Open the TCP (optionally SSL-wrapped) socket, create the libev
        watchers, send the OPTIONS message and make sure the loop is running.

        If any step of socket setup raises, the socket is closed before the
        exception propagates, so its file descriptor is not leaked.
        """
        Connection.__init__(self, *args, **kwargs)

        self.connected_event = Event()
        self._iobuf = StringIO()

        self._callbacks = {}
        self._push_watchers = defaultdict(set)
        self.deque = deque()
        self._deque_lock = Lock()

        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        socket_ready = False
        try:
            if self.ssl_options:
                if not ssl:
                    raise Exception("This version of Python was not compiled with SSL support")
                self._socket = ssl.wrap_socket(self._socket, **self.ssl_options)
            self._socket.settimeout(1.0)  # TODO potentially make this value configurable
            self._socket.connect((self.host, self.port))
            self._socket.setblocking(0)

            if self.sockopts:
                for sockopt in self.sockopts:
                    self._socket.setsockopt(*sockopt)
            socket_ready = True
        finally:
            # `finally` (rather than except/raise) preserves the original
            # exception and traceback on Python 2 as well
            if not socket_ready:
                self._socket.close()

        with _loop_lock:
            self._read_watcher = libev.IO(self._socket._sock, libev.EV_READ, _loop, self.handle_read)
            self._write_watcher = libev.IO(self._socket._sock, libev.EV_WRITE, _loop, self.handle_write)

        self._send_options_message()

        self.__class__._connection_created(self)

        # start the global event loop if needed
        _start_loop()
        _loop_notifier.send()

    def close(self):
        """
        Close the connection once; later calls are no-ops.

        Queues the watchers for removal, closes the socket and, unless the
        connection was defuncted (which errors callbacks itself), fails all
        pending callbacks with :exc:`ConnectionShutdown`.
        """
        with self.lock:
            if self.is_closed:
                return
            self.is_closed = True

        log.debug("Closing connection (%s) to %s", id(self), self.host)
        self.__class__._connection_destroyed(self)
        _loop_notifier.send()
        self._socket.close()

        # don't leave in-progress operations hanging
        if not self.is_defunct:
            self._error_all_callbacks(
                ConnectionShutdown("Connection to %s was closed" % self.host))

    def defunct(self, exc):
        """
        Mark the connection as failed because of ``exc`` and tear it down.

        Records ``exc`` as ``last_error``, closes the connection, errors all
        pending callbacks and sets ``connected_event`` so that ``factory``
        stops waiting.

        :returns: ``exc``, or None if the connection was already defunct or
            closed.
        """
        with self.lock:
            if self.is_defunct or self.is_closed:
                return
            self.is_defunct = True

        if sys.exc_info()[0] is not None:
            # called while an exception is being handled: include its traceback
            log.debug("Defuncting connection (%s) to %s: %s\n%s",
                      id(self), self.host, exc, traceback.format_exc())
        else:
            log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)

        self.last_error = exc
        self.close()
        self._error_all_callbacks(exc)
        self.connected_event.set()
        return exc

    def _error_all_callbacks(self, exc):
        """
        Fail every pending request callback with a :exc:`ConnectionShutdown`
        built from ``str(exc)``. Errors raised by callbacks are logged and
        ignored.
        """
        with self.lock:
            callbacks = self._callbacks
            self._callbacks = {}
        new_exc = ConnectionShutdown(str(exc))
        for cb in callbacks.values():
            try:
                cb(new_exc)
            except Exception:
                log.warning("Ignoring unhandled exception while erroring callbacks for a "
                            "failed connection (%s) to host %s:",
                            id(self), self.host, exc_info=True)

    def _defunct_on_libev_error(self, errno):
        """Defunct the connection after libev reported EV_ERROR on a watcher."""
        if errno:
            exc = IOError(errno, os.strerror(errno))
        else:
            exc = Exception("libev reported an error")
        self.defunct(exc)

    def handle_write(self, watcher, revents, errno=None):
        """
        Write-watcher callback: drain the outgoing ``deque`` into the socket.

        Stops when the deque is empty or the socket would block. In the
        latter case the unsent data is put back at the front of the deque
        and the method returns to wait for the next EV_WRITE. Any other
        socket error defuncts the connection.
        """
        if revents & libev.EV_ERROR:
            self._defunct_on_libev_error(errno)
            return

        while True:
            try:
                with self._deque_lock:
                    next_msg = self.deque.popleft()
            except IndexError:
                return

            try:
                sent = self._socket.send(next_msg)
            except socket.error as err:
                if err.args[0] in NONBLOCKING:
                    # kernel buffer is full: requeue and wait for the next
                    # EV_WRITE instead of spinning on send()
                    with self._deque_lock:
                        self.deque.appendleft(next_msg)
                else:
                    self.defunct(err)
                return
            else:
                if sent < len(next_msg):
                    with self._deque_lock:
                        self.deque.appendleft(next_msg[sent:])

    def handle_read(self, watcher, revents, errno=None):
        """
        Read-watcher callback: pull available bytes into ``_iobuf`` and
        dispatch every complete frame.

        An empty ``recv`` result means the peer closed the socket. Any
        buffered frames are processed first, then the connection is closed.
        A would-block error just ends the read. Any other socket error
        defuncts the connection.
        """
        if revents & libev.EV_ERROR:
            self._defunct_on_libev_error(errno)
            return

        peer_closed = False
        try:
            while True:
                buf = self._socket.recv(self.in_buffer_size)
                if not buf:
                    peer_closed = True
                    break
                self._iobuf.write(buf)
                if len(buf) < self.in_buffer_size:
                    break
        except socket.error as err:
            if err.args[0] not in NONBLOCKING:
                self.defunct(err)
                return

        if self._iobuf.tell():
            self._process_io_buffer()

        if peer_closed:
            log.debug("Connection %s closed by server", self)
            self.close()

    def _process_io_buffer(self):
        """
        Extract and dispatch complete frames from ``_iobuf`` via
        ``process_msg``.

        The length of a frame's body is read from header bytes 4-8. Any
        trailing partial frame is kept in a fresh buffer for the next read.
        """
        while True:
            pos = self._iobuf.tell()
            if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
                # we don't have a complete header yet or we
                # already saw a header, but we don't have a
                # complete message yet
                return

            # have enough for header, read body len from header
            self._iobuf.seek(4)
            body_len = int32_unpack(self._iobuf.read(4))

            # seek to end to get length of current buffer
            self._iobuf.seek(0, os.SEEK_END)
            pos = self._iobuf.tell()

            if pos < body_len + 8:
                # remember the full frame size so the header isn't re-parsed
                # until enough bytes have arrived
                self._total_reqd_bytes = body_len + 8
                return

            # read message header and body
            self._iobuf.seek(0)
            msg = self._iobuf.read(8 + body_len)

            # leave leftover in current buffer
            leftover = self._iobuf.read()
            self._iobuf = StringIO()
            self._iobuf.write(leftover)

            self._total_reqd_bytes = 0
            self.process_msg(msg, body_len)

    def handle_pushed(self, response):
        """Invoke every watcher registered for ``response.event_type``; errors are logged."""
        log.debug("Message pushed from server: %r", response)
        for cb in self._push_watchers.get(response.event_type, []):
            try:
                cb(response.event_args)
            except Exception:
                log.exception("Pushed event handler errored, ignoring:")

    def push(self, data):
        """
        Queue ``data`` for sending in chunks of at most ``out_buffer_size``
        and wake the loop so that ``loop_will_run`` starts the write watcher.
        """
        chunk_size = self.out_buffer_size
        if len(data) > chunk_size:
            chunks = [data[i:i + chunk_size] for i in xrange(0, len(data), chunk_size)]
        else:
            chunks = [data]

        with self._deque_lock:
            self.deque.extend(chunks)

        _loop_notifier.send()

    def send_msg(self, msg, cb, wait_for_id=False):
        """
        Serialize ``msg`` on a free stream id and queue it for sending.

        :param cb: called with the response, or with a
            :exc:`ConnectionShutdown` if the connection fails first.
        :param wait_for_id: block for a free stream id instead of raising
            :exc:`ConnectionBusy`.
        :returns: the stream id used.
        :raises ConnectionShutdown: if the connection is defunct or closed.
        """
        if self.is_defunct:
            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionShutdown("Connection to %s is closed" % self.host)

        if not wait_for_id:
            try:
                request_id = self._id_queue.get_nowait()
            except Queue.Empty:
                raise ConnectionBusy(
                    "Connection to %s is at the max number of requests" % self.host)
        else:
            request_id = self._id_queue.get()

        self._callbacks[request_id] = cb
        self.push(msg.to_string(request_id, compression=self.compressor))
        return request_id

    def wait_for_response(self, msg, timeout=None):
        """Send one message and block for its response (see ``wait_for_responses``)."""
        return self.wait_for_responses(msg, timeout=timeout)[0]

    def wait_for_responses(self, *msgs, **kwargs):
        """
        Send ``msgs`` and block until all responses have arrived.

        While there is no room under ``MAX_STREAM_PER_CONNECTION`` in-flight
        requests, it polls in 10 ms steps, charging each step against
        ``timeout``. The remaining timeout is then handed to
        ``ResponseWaiter.deliver``.

        :raises OperationTimedOut: if the timeout runs out while waiting for
            free stream slots, or is raised by ``deliver``.
        Any other exception from ``deliver`` defuncts the connection and is
        re-raised.
        """
        timeout = kwargs.get('timeout')
        waiter = ResponseWaiter(self, len(msgs))

        # busy wait for sufficient space on the connection
        messages_sent = 0
        while True:
            needed = len(msgs) - messages_sent
            with self.lock:
                available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
                self.in_flight += available

            for i in range(messages_sent, messages_sent + available):
                self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
            messages_sent += available

            if messages_sent == len(msgs):
                break
            else:
                if timeout is not None:
                    timeout -= 0.01
                    if timeout <= 0.0:
                        raise OperationTimedOut()
                time.sleep(0.01)

        try:
            return waiter.deliver(timeout)
        except OperationTimedOut:
            raise
        except Exception as exc:
            self.defunct(exc)
            raise

    def register_watcher(self, event_type, callback):
        """Add ``callback`` for ``event_type`` and send a REGISTER for it, waiting for the reply."""
        self._push_watchers[event_type].add(callback)
        self.wait_for_response(RegisterMessage(event_list=[event_type]))

    def register_watchers(self, type_callback_dict):
        """Add one callback per event type and send a single REGISTER covering all of them."""
        for event_type, callback in type_callback_dict.items():
            self._push_watchers[event_type].add(callback)
        self.wait_for_response(RegisterMessage(event_list=type_callback_dict.keys()))
|
|
|
|
|
|
# Prepare watcher on the shared loop. Its callback, LibevConnection.loop_will_run,
# starts and stops the per-connection read/write watchers from inside the loop.
# This presumes libev.Prepare wraps libev's ev_prepare, i.e. it runs before
# each blocking iteration; confirm against libevwrapper.
_preparer = libev.Prepare(_loop, LibevConnection.loop_will_run)
# prevent _preparer from keeping the loop from returning
_loop.unref()
_preparer.start()
|