Files
deb-python-cassandra-driver/cassandra/io/asyncorereactor.py
2013-07-02 16:49:23 -05:00

254 lines
7.3 KiB
Python

from collections import defaultdict, deque
from functools import partial
import logging
import socket
import sys
from threading import Event, Lock, Thread
import traceback
from Queue import Queue, Empty
import asyncore

from cassandra.connection import (Connection, ResponseWaiter, ConnectionException,
                                  ConnectionBusy, NONBLOCKING)
from cassandra.marshal import int32_unpack
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Truthy while an event loop thread is considered to be running:
# _start_loop() sets it to True before spawning the thread and _run_loop()
# resets it to False once asyncore.loop() has returned.  It starts out as
# None, meaning no loop has been started yet.
_loop_started = None
# Guards every read and write of _loop_started.
_loop_lock = Lock()
def _run_loop():
    """
    Thread target that runs the shared asyncore event loop.

    Blocks inside ``asyncore.loop()`` until it returns, then clears the
    module-level ``_loop_started`` flag (under ``_loop_lock``) so that a
    later ``_start_loop()`` call can spawn a fresh loop thread.

    The flag is cleared in a ``finally`` clause: if the loop dies with an
    exception, the flag must still be reset, otherwise no new loop thread
    would ever be started.  The exception itself is not swallowed.
    """
    global _loop_started

    log.debug("Starting asyncore event loop")
    try:
        asyncore.loop(timeout=0.001, use_poll=True, count=None)
        log.debug("Asyncore event loop ended")
    finally:
        with _loop_lock:
            _loop_started = False
def _start_loop():
    """
    Ensure the shared asyncore event loop thread has been started.

    The ``_loop_started`` flag is checked and set while holding
    ``_loop_lock``; the thread itself is created after the lock has been
    released.  The thread is a daemon thread named ``"event_loop"`` that
    runs ``_run_loop``.

    Returns ``True`` when this call spawned the thread and ``False`` when
    a loop was already marked as started.
    """
    global _loop_started

    with _loop_lock:
        already_started = bool(_loop_started)
        _loop_started = True

    if already_started:
        return False

    loop_thread = Thread(target=_run_loop, name="event_loop")
    loop_thread.daemon = True
    loop_thread.start()
    return True
class AsyncoreConnection(Connection, asyncore.dispatcher):
    """
    An implementation of :class:`.Connection` that utilizes the ``asyncore``
    module in the Python standard library for its event loop.

    Outgoing data is queued by :meth:`push` and written to the socket by
    :meth:`handle_write`; incoming bytes are accumulated by
    :meth:`handle_read`, which cuts them into complete messages and passes
    each one to ``process_msg``.  The ``handle_*`` methods are the
    ``asyncore.dispatcher`` callbacks.
    """

    # Bytes received from the socket that do not yet form a complete message.
    _buf = ""

    # Total size (8-byte header plus body) of the message currently being
    # assembled, or 0 when no header is pending.
    _total_reqd_bytes = 0

    # Values reported to asyncore through writable() and readable().
    _writable = False
    _readable = False

    # Set to True once a push-event watcher has been registered.
    _have_listeners = False

    @classmethod
    def factory(cls, *args, **kwargs):
        """
        Create a connection and block until its ``connected_event`` is set.

        All arguments are forwarded to the constructor.  Raises
        ``last_error`` if one has been recorded by the time the event is
        set; otherwise returns the new connection.
        """
        conn = cls(*args, **kwargs)
        conn.connected_event.wait()
        if conn.last_error:
            raise conn.last_error
        else:
            return conn

    def __init__(self, *args, **kwargs):
        """
        Initialise both base classes, open a TCP socket to
        ``(self.host, self.port)`` and make sure the event loop is running.

        All arguments are forwarded to ``Connection.__init__``.
        """
        Connection.__init__(self, *args, **kwargs)
        asyncore.dispatcher.__init__(self)

        self.connected_event = Event()

        # request id -> callback for requests registered through send_msg()
        self._callbacks = {}
        # event type -> set of callbacks registered for pushed events
        self._push_watchers = defaultdict(set)
        # outgoing chunks waiting to be written by handle_write()
        self.deque = deque()

        log.debug("Opening socket to %s", self.host)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.connect((self.host, self.port))

        if self.sockopts:
            for args in self.sockopts:
                self.socket.setsockopt(*args)

        self._writable = True
        self._readable = True

        # start the global event loop if needed
        _start_loop()

    def close(self):
        """
        Close the socket and fail any registered callbacks.

        Only the first call has any effect.  ``connected_event`` is set so
        that a caller blocked in :meth:`factory` is released.  Callbacks
        are errored with a :class:`ConnectionException` unless the
        connection is already defunct.
        """
        with self.lock:
            if self.is_closed:
                return
            self.is_closed = True

        log.debug("Closing connection to %s", self.host)
        self._writable = False
        self._readable = False
        asyncore.dispatcher.close(self)
        log.debug("Closed socket to %s", self.host)

        # don't leave in-progress operations hanging
        self.connected_event.set()
        if not self.is_defunct:
            self._error_all_callbacks(
                ConnectionException("Connection to %s was closed" % self.host))

    def __del__(self):
        self.close()

    def defunct(self, exc):
        """
        Mark the connection as defunct because of ``exc``.

        Records ``exc`` as ``last_error``, passes it to every registered
        callback and sets ``connected_event``.  Returns ``exc``, or
        ``None`` if the connection was already defunct.
        """
        if self.is_defunct:
            return

        # format_exc() takes an optional traceback depth limit, not an
        # exception; it formats the exception currently being handled.
        log.debug("Defuncting connection to %s: %s\n%s",
                  self.host, exc, traceback.format_exc())
        self.last_error = exc
        self.is_defunct = True
        self._error_all_callbacks(exc)
        self.connected_event.set()
        return exc

    def _error_all_callbacks(self, exc):
        """Call every registered request callback with ``exc``."""
        for cb in self._callbacks.values():
            cb(exc)

    def handle_connect(self):
        """asyncore callback: the socket has connected."""
        self._send_options_message()

    def handle_error(self):
        """asyncore callback: defunct with the exception being handled."""
        self.defunct(sys.exc_info()[1])

    def handle_close(self):
        """asyncore callback: the peer closed the connection."""
        log.debug("connection closed by server")
        self.close()

    def handle_write(self):
        """
        asyncore callback: write the next queued chunk to the socket.

        A chunk that could not be sent (non-blocking error) or was only
        partially sent is put back at the front of the queue.  Any other
        socket error makes the connection defunct.
        """
        # push() changes the deque and _writable while holding self.lock.
        # Take the same lock here so a concurrent push() cannot have its
        # "_writable = True" overwritten after we have seen an empty deque.
        with self.lock:
            try:
                next_msg = self.deque.popleft()
            except IndexError:
                self._writable = False
                return

        try:
            sent = self.send(next_msg)
        except socket.error as err:
            if err.args[0] in NONBLOCKING:
                # the socket was not ready; retry the same chunk later
                self.deque.appendleft(next_msg)
            else:
                self.defunct(err)
            return

        if sent < len(next_msg):
            # partial write: requeue the unsent tail at the front
            self.deque.appendleft(next_msg[sent:])

        with self.lock:
            if not self.deque:
                self._writable = False

        self._readable = True

    def handle_read(self):
        """
        asyncore callback: read available bytes and dispatch full messages.

        Each message has an 8-byte header whose bytes 4-7 hold the body
        length (decoded with ``int32_unpack``).  Complete messages are
        handed to ``process_msg``; leftover bytes stay in ``_buf`` until
        more data arrives.  An empty read closes the connection.
        """
        try:
            buf = self.recv(self.in_buffer_size)
        except socket.error as err:
            if err.args[0] not in NONBLOCKING:
                self.defunct(err)
            return

        if not buf:
            log.debug("connection closed by server")
            self.close()
            return

        self._buf += buf
        while True:
            if len(self._buf) < 8:
                # we don't have a complete header yet
                break
            elif self._total_reqd_bytes and len(self._buf) < self._total_reqd_bytes:
                # we already saw a header, but we don't have a complete message yet
                break
            else:
                body_len = int32_unpack(self._buf[4:8])
                if len(self._buf) - 8 >= body_len:
                    msg = self._buf[:8 + body_len]
                    self._buf = self._buf[8 + body_len:]
                    self._total_reqd_bytes = 0
                    self.process_msg(msg, body_len)
                else:
                    # remember the full size; the next pass breaks out above
                    self._total_reqd_bytes = body_len + 8

        if not self._callbacks:
            # nothing is waiting for a response
            self._readable = False

    def handle_pushed(self, response):
        """
        Pass ``response`` to every watcher registered for ``response.type``.

        An exception raised by a watcher is logged and ignored so that the
        remaining watchers still run.
        """
        for cb in self._push_watchers[response.type]:
            try:
                cb(response)
            except Exception:
                log.exception("Pushed event handler errored, ignoring:")

    def push(self, data):
        """
        Queue ``data`` for writing, split into chunks of at most
        ``out_buffer_size``, and mark the connection as writable.
        """
        sabs = self.out_buffer_size
        if len(data) > sabs:
            chunks = [data[i:i + sabs] for i in xrange(0, len(data), sabs)]
        else:
            chunks = [data]

        with self.lock:
            self.deque.extend(chunks)
            self._writable = True

    def writable(self):
        """asyncore hook: whether there is queued data to write."""
        return self._writable

    def readable(self):
        """
        asyncore hook: whether the socket should be polled for reads.

        True while reading has been requested, or while push watchers are
        registered and the connection is neither defunct nor closed.
        """
        # The parentheses matter: "a and not b or c" parses as
        # "(a and not b) or c", which would report a closed connection
        # as readable.
        return self._readable or (
            self._have_listeners and not (self.is_defunct or self.is_closed))

    def send_msg(self, msg, cb):
        """
        Queue ``msg`` for sending and register ``cb`` under its request id.

        Returns the request id taken from ``_id_queue``.

        Raises :class:`ConnectionException` if the connection is defunct
        or closed, and :class:`ConnectionBusy` if no request id is
        available.
        """
        if self.is_defunct:
            raise ConnectionException("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionException("Connection to %s is closed" % self.host)

        try:
            request_id = self._id_queue.get_nowait()
        except Empty:
            raise ConnectionBusy(
                "Connection to %s is at the max number of requests" % self.host)

        self._callbacks[request_id] = cb
        self.push(msg.to_string(request_id, compression=self.compressor))
        return request_id

    def wait_for_response(self, msg):
        """Send ``msg`` and return the first item of ``wait_for_responses``."""
        return self.wait_for_responses(msg)[0]

    def wait_for_responses(self, *msgs):
        """
        Send every message in ``msgs`` and return ``waiter.deliver()``.

        Each message's callback is ``waiter.got_response`` bound to the
        message's position in ``msgs``.
        """
        waiter = ResponseWaiter(len(msgs))
        for i, msg in enumerate(msgs):
            self.send_msg(msg, partial(waiter.got_response, index=i))

        return waiter.deliver()

    def register_watcher(self, event_type, callback):
        """Register ``callback`` for pushed events of type ``event_type``."""
        self._push_watchers[event_type].add(callback)
        self._have_listeners = True

    def register_watchers(self, type_callback_dict):
        """Register every ``event_type -> callback`` pair in the mapping."""
        for event_type, callback in type_callback_dict.items():
            self.register_watcher(event_type, callback)