from collections import defaultdict, deque
from functools import partial
import logging
import socket
import sys
from threading import Event, Lock, Thread
import traceback
from Queue import Queue, Empty
import asyncore

from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown,
                                  ConnectionBusy, NONBLOCKING)
from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack

log = logging.getLogger(__name__)

# Set by _start_loop() when an event loop thread is launched and cleared by
# _run_loop() just before that thread exits.
_loop_started = False
_loop_lock = Lock()

# Connections that have been created but have not yet reported
# handle_connect() (or been closed).  Guarded by _starting_conns_lock.
_starting_conns = set()
_starting_conns_lock = Lock()


def _run_loop():
    """
    Target of the ``event_loop`` thread: run ``asyncore.loop()`` while
    holding ``_loop_lock``.

    Whenever ``asyncore.loop()`` returns, the loop is restarted if any
    connection is still registered in ``_starting_conns``; otherwise the
    thread finishes.  An exception escaping ``asyncore.loop()`` is logged
    and also ends the thread.  ``_loop_started`` is reset on the way out so
    that a later ``_start_loop()`` call can launch a new thread.
    """
    global _loop_started
    log.debug("Starting asyncore event loop")
    with _loop_lock:
        while True:
            try:
                asyncore.loop(timeout=0.001, use_poll=True, count=None)
            except Exception:
                log.debug("Asyncore event loop stopped unexepectedly", exc_info=True)
                break

            # Keep looping while some connection is still being set up.
            with _starting_conns_lock:
                if not _starting_conns:
                    break

        _loop_started = False
        if log:
            # this can happen during interpreter shutdown
            log.debug("Asyncore event loop ended")


def _start_loop():
    """
    Start the global event loop thread unless one is already running.

    The lock is acquired without blocking: ``_run_loop()`` holds
    ``_loop_lock`` for its whole lifetime, so failing to acquire it means a
    loop thread is active and nothing needs to be done.
    """
    global _loop_started
    should_start = False
    did_acquire = False
    try:
        did_acquire = _loop_lock.acquire(False)
        if did_acquire and not _loop_started:
            _loop_started = True
            should_start = True
    finally:
        if did_acquire:
            _loop_lock.release()

    if should_start:
        t = Thread(target=_run_loop, name="event_loop")
        t.daemon = True
        t.start()


class AsyncoreConnection(Connection, asyncore.dispatcher):
    """
    An implementation of :class:`.Connection` that utilizes the ``asyncore``
    module in the Python standard library for its event loop.
    """

    # bytes received from the socket that have not yet been split into messages
    _buf = ""
    # header + body size of the message currently being received (0 if unknown)
    _total_reqd_bytes = 0

    # values reported to asyncore through writable() / readable()
    _writable = False
    _readable = False
    # True once at least one push-event watcher has been registered
    _have_listeners = False

    @classmethod
    def factory(cls, *args, **kwargs):
        """
        Create a connection and block until ``connected_event`` is set.

        Returns the new connection, or raises ``last_error`` if one was
        recorded while waiting.
        """
        conn = cls(*args, **kwargs)
        conn.connected_event.wait()
        if conn.last_error:
            raise conn.last_error
        else:
            return conn

    def __init__(self, *args, **kwargs):
        """
        Initialise both base classes, open a TCP socket to
        ``(self.host, self.port)``, apply ``self.sockopts`` and make sure the
        global event loop thread is running.
        """
        Connection.__init__(self, *args, **kwargs)
        asyncore.dispatcher.__init__(self)

        self.connected_event = Event()

        self._callbacks = {}
        self._push_watchers = defaultdict(set)
        self.deque = deque()

        # Registering here keeps _run_loop() alive until this connection
        # reaches handle_connect() or close().
        with _starting_conns_lock:
            _starting_conns.add(self)

        log.debug("Opening socket to %s", self.host)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.connect((self.host, self.port))

        if self.sockopts:
            # named opt_args so the *args parameter is not shadowed
            for opt_args in self.sockopts:
                self.socket.setsockopt(*opt_args)

        self._writable = True
        self._readable = True

        # start the global event loop if needed
        _start_loop()

    def close(self):
        """
        Close the connection; calls after the first one are no-ops.

        Closes the asyncore dispatcher, sets ``connected_event`` and, unless
        the connection is already defunct, passes a
        :class:`ConnectionShutdown` to every registered callback.
        """
        with self.lock:
            if self.is_closed:
                return
            self.is_closed = True

        log.debug("Closing connection to %s", self.host)
        self._writable = False
        self._readable = False
        asyncore.dispatcher.close(self)
        log.debug("Closed socket to %s", self.host)

        with _starting_conns_lock:
            _starting_conns.discard(self)

        # don't leave in-progress operations hanging
        self.connected_event.set()
        if not self.is_defunct:
            self._error_all_callbacks(
                ConnectionShutdown("Connection to %s was closed" % self.host))

    def __del__(self):
        try:
            self.close()
        except TypeError:
            pass

    def defunct(self, exc):
        """
        Mark the connection as defunct because of ``exc``.

        Logs the error (with a traceback when called while an exception is
        being handled), stores it in ``last_error``, passes it on to every
        registered callback and sets ``connected_event``.

        Returns ``exc``, or ``None`` if the connection was already defunct.
        """
        with self.lock:
            if self.is_defunct:
                return
            self.is_defunct = True

        # traceback.format_exc() takes an optional ``limit``, not an
        # exception, and only has something to report while an exception is
        # being handled; sys.exc_info() tells us whether that is the case.
        if sys.exc_info()[0] is not None:
            log.debug("Defuncting connection to %s: %s\n%s",
                      self.host, exc, traceback.format_exc())
        else:
            log.debug("Defuncting connection to %s: %s", self.host, exc)

        self.last_error = exc
        self._error_all_callbacks(exc)
        self.connected_event.set()
        return exc

    def _error_all_callbacks(self, exc):
        """
        Call every registered callback with a :class:`ConnectionShutdown`
        built from ``str(exc)``.
        """
        new_exc = ConnectionShutdown(str(exc))
        for cb in self._callbacks.values():
            cb(new_exc)

    def handle_connect(self):
        """ asyncore hook: the connection is no longer "starting". """
        with _starting_conns_lock:
            _starting_conns.discard(self)
        self._send_options_message()

    def handle_error(self):
        """ asyncore hook: defunct with the exception currently being handled. """
        self.defunct(sys.exc_info()[1])

    def handle_close(self):
        """ asyncore hook: close this connection. """
        log.debug("connection closed by server")
        self.close()

    def handle_write(self):
        """
        asyncore hook: send the next queued chunk.

        A chunk that could not be sent because of a ``NONBLOCKING`` error,
        or the unsent tail of a partially sent chunk, is put back at the
        front of the queue.  Any other socket error defuncts the connection.
        """
        try:
            next_msg = self.deque.popleft()
        except IndexError:
            self._writable = False
            return

        try:
            sent = self.send(next_msg)
        except socket.error as err:
            if err.args[0] in NONBLOCKING:
                self.deque.appendleft(next_msg)
            else:
                self.defunct(err)
            return
        else:
            if sent < len(next_msg):
                self.deque.appendleft(next_msg[sent:])

            if not self.deque:
                self._writable = False

            self._readable = True

    def handle_read(self):
        """
        asyncore hook: read available bytes and dispatch complete messages.

        Each message has an 8-byte header whose bytes 4-7 hold the body
        length (decoded with ``int32_unpack``).  Complete messages are passed
        to ``process_msg()``; a partial message stays in ``_buf`` until more
        data arrives.  An empty read closes the connection.
        """
        try:
            buf = self.recv(self.in_buffer_size)
        except socket.error as err:
            if err.args[0] not in NONBLOCKING:
                self.defunct(err)
            return

        if buf:
            self._buf += buf
            while True:
                if len(self._buf) < 8:
                    # we don't have a complete header yet
                    break
                elif self._total_reqd_bytes and len(self._buf) < self._total_reqd_bytes:
                    # we already saw a header, but we don't have a
                    # complete message yet
                    break
                else:
                    body_len = int32_unpack(self._buf[4:8])
                    if len(self._buf) - 8 >= body_len:
                        msg = self._buf[:8 + body_len]
                        self._buf = self._buf[8 + body_len:]
                        self._total_reqd_bytes = 0
                        self.process_msg(msg, body_len)
                    else:
                        # remember the full size; the next pass through the
                        # loop breaks until that many bytes are buffered
                        self._total_reqd_bytes = body_len + 8

            if not self._callbacks:
                self._readable = False
        else:
            self.close()

    def handle_pushed(self, response):
        """
        Pass ``response.event_args`` to every watcher registered for
        ``response.event_type``; exceptions raised by watchers are logged
        and otherwise ignored.
        """
        log.debug("Message pushed from server: %r", response)
        for cb in self._push_watchers.get(response.event_type, []):
            try:
                cb(response.event_args)
            except Exception:
                log.exception("Pushed event handler errored, ignoring:")

    def push(self, data):
        """
        Queue ``data`` for sending, split into chunks of at most
        ``out_buffer_size`` bytes, and mark the connection writable.
        """
        sabs = self.out_buffer_size
        if len(data) > sabs:
            chunks = [data[i:i + sabs] for i in xrange(0, len(data), sabs)]
        else:
            chunks = [data]

        with self.lock:
            self.deque.extend(chunks)

        self._writable = True

    def writable(self):
        """ asyncore hook: whether the event loop should watch for writes. """
        return self._writable

    def readable(self):
        """
        asyncore hook: whether the event loop should watch for reads.

        Besides the explicit ``_readable`` flag, a connection with push
        watchers stays readable for as long as it is neither defunct nor
        closed.
        """
        return self._readable or (self._have_listeners and not (self.is_defunct or self.is_closed))

    def send_msg(self, msg, cb):
        """
        Queue ``msg`` for sending and register ``cb`` under the request id
        taken from ``_id_queue``.

        Returns the request id.  Raises :class:`ConnectionShutdown` if the
        connection is defunct or closed, and :class:`ConnectionBusy` if no
        request id is available.
        """
        if self.is_defunct:
            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionShutdown("Connection to %s is closed" % self.host)

        try:
            request_id = self._id_queue.get_nowait()
        except Empty:
            raise ConnectionBusy(
                "Connection to %s is at the max number of requests" % self.host)

        self._callbacks[request_id] = cb
        self.push(msg.to_string(request_id, compression=self.compressor))
        return request_id

    def wait_for_response(self, msg):
        """ Send ``msg`` and return the first item of ``wait_for_responses(msg)``. """
        return self.wait_for_responses(msg)[0]

    def wait_for_responses(self, *msgs):
        """
        Send every message in ``msgs`` and return the result of
        ``ResponseWaiter.deliver()`` for them.
        """
        waiter = ResponseWaiter(len(msgs))
        for i, msg in enumerate(msgs):
            self.send_msg(msg, partial(waiter.got_response, index=i))

        return waiter.deliver()

    def register_watcher(self, event_type, callback):
        """
        Add ``callback`` as a watcher for ``event_type`` and send a
        ``RegisterMessage`` for that event type, waiting for the response.
        """
        self._push_watchers[event_type].add(callback)
        self._have_listeners = True
        self.wait_for_response(RegisterMessage(event_list=[event_type]))

    def register_watchers(self, type_callback_dict):
        """
        Add one watcher per ``event_type -> callback`` entry and send a
        single ``RegisterMessage`` covering all of the event types, waiting
        for the response.
        """
        for event_type, callback in type_callback_dict.items():
            self._push_watchers[event_type].add(callback)
        self._have_listeners = True
        self.wait_for_response(RegisterMessage(event_list=type_callback_dict.keys()))