diff --git a/cassandra/cluster.py b/cassandra/cluster.py index d4ac1e5f..8924c72b 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -151,6 +151,8 @@ class ResponseFuture(object): self._set_final_exception(response) else: # IGNORE self._set_final_result(None) + elif isinstance(response, Exception): + self._set_final_exception(response) else: # we got some other kind of response message msg = "Got unexpected message: %r" % (response,) diff --git a/cassandra/connection.py b/cassandra/connection.py index da266a52..6cd1d318 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1,13 +1,14 @@ -import pyev -import errno from collections import defaultdict, deque -from functools import partial +import errno +from functools import partial, wraps import logging import socket from threading import RLock, Event, Lock, Thread import traceback from Queue import Queue +import pyev + from cassandra import ConsistencyLevel from cassandra.marshal import (int8_unpack, int32_unpack) from cassandra.decoder import (OptionsMessage, ReadyMessage, AuthenticateMessage, @@ -49,7 +50,7 @@ class ConnectionException(Exception): self.host = host -class ConnectionBusy(ConnectionException): +class ConnectionBusy(Exception): pass @@ -57,10 +58,6 @@ class ProgrammingError(Exception): pass -class InternalError(Exception): - pass - - class ProtocolError(Exception): pass @@ -108,6 +105,18 @@ def _start_loop(): return should_start +def defunct_on_error(f): + + @wraps(f) + def wrapper(self, *args, **kwargs): + try: + return f(self, *args, **kwargs) + except Exception, exc: + self.defunct(exc) + + return wrapper + + class Connection(object): in_buffer_size = 4096 @@ -198,11 +207,26 @@ class Connection(object): with _loop_lock: _loop_notifier.send() + # don't leave in-progress operations hanging + if not self.is_defunct: + self._error_all_callbacks( + ConnectionException("Connection to %s was closed" % self.host)) + def __del__(self): self.close() def defunct(self, exc): - pass + log.debug("Defuncting connection to %s: %s %s" % + (self.host, exc, traceback.format_exc(exc))) + self.last_error = exc + self.is_defunct = True + self._error_all_callbacks(exc) + self.connected_event.set() + return exc + + def _error_all_callbacks(self, exc): + for cb in self._callbacks.values(): + cb(exc) def handle_write(self, watcher, revents): try: @@ -217,7 +241,7 @@ class Connection(object): if (err.args[0] in NONBLOCKING): self.deque.appendleft(next_msg) else: - self.handle_error() + self.defunct(err) return else: if sent < len(next_msg): @@ -231,7 +255,7 @@ class Connection(object): buf = self._socket.recv(self.in_buffer_size) except socket.error, err: if err.args[0] not in NONBLOCKING: - self.handle_error() + self.defunct(err) return if buf: @@ -256,6 +280,7 @@ class Connection(object): logging.debug("connection closed by server") self.close() + @defunct_on_error def process_msg(self, msg, body_len): version, flags, stream_id, opcode = map(int8_unpack, msg[:4]) if stream_id < 0: @@ -291,16 +316,12 @@ class Connection(object): except: log.exception("Callback handler errored, ignoring:") - def handle_error(self): - log.error(traceback.format_exc()) - self.is_defunct = True - def handle_pushed(self, response): for cb in self._push_watchers[response.type]: try: cb(response) except: - log.error("Pushed event handler errored, ignoring: %s" % traceback.format_exc()) + log.exception("Pushed event handler errored, ignoring:") def push(self, data): sabs = self.out_buffer_size @@ -317,6 +338,11 @@ class Connection(object): _loop_notifier.send() def send_msg(self, msg, cb): + if self.is_defunct: + raise ConnectionException("Connection to %s is defunct" % self.host) + elif self.is_closed: + raise ConnectionException("Connection to %s is closed" % self.host) + try: request_id = self._id_queue.get_nowait() except Queue.EMPTY: @@ -331,8 +357,8 @@ class Connection(object): waiter = ResponseWaiter(len(msgs)) for i, msg in enumerate(msgs): self.send_msg(msg, partial(waiter.got_response, index=i)) - waiter.event.wait() - return waiter.responses + + return waiter.deliver() def register_watcher(self, event_type, callback): self._push_watchers[event_type].add(callback) @@ -341,6 +367,7 @@ class Connection(object): for event_type, callback in type_callback_dict.items(): self.register_watcher(event_type, callback) + @defunct_on_error def _handle_options_response(self, options_response): log.debug("Received options response on new Connection from %s" % self.host) self.supported_cql_versions = options_response.cql_versions @@ -376,6 +403,7 @@ class Connection(object): sm = StartupMessage(cqlversion=self.cql_version, options=opts) self.send_msg(sm, cb=self._handle_startup_response) + @defunct_on_error def _handle_startup_response(self, startup_response): if isinstance(startup_response, ReadyMessage): log.debug("Got ReadyMessage on new Connection from %s" % self.host) @@ -386,26 +414,20 @@ class Connection(object): log.debug("Got AuthenticateMessage on new Connection from %s" % self.host) if self.credentials is None: - self.last_error = ProgrammingError( - 'Remote end requires authentication.') - self.connected_event.set() - return + raise ProgrammingError('Remote end requires authentication.') self.authenticator = startup_response.authenticator cm = CredentialsMessage(creds=self.credentials) self.send_msg(cm, cb=self._handle_startup_response) elif isinstance(startup_response, ErrorMessage): log.debug("Received ErrorMessage on new Connection from %s" % self.host) - self.last_error = ProgrammingError( + raise ProgrammingError( "Server did not accept credentials. %s" % startup_response.summary_msg()) - self.connected_event.set() else: - log.error("Unexpected response during Connection setup") - self.last_error = InternalError( - "Unexpected response %r during connection setup" - % (startup_response,)) - self.connected_event.set() + msg = "Unexpected response during Connection setup: %r" % (startup_response,) + log.error(msg) + raise ProtocolError(msg) def set_keyspace(self, keyspace): if not keyspace: @@ -422,21 +444,34 @@ class Connection(object): if isinstance(result, ResultMessage): self.keyspace = keyspace else: - self.defunct(ConnectionException( + raise self.defunct(ConnectionException( "Problem while setting keyspace: %r" % (result,), self.host)) except Exception, exc: - self.defunct(ConnectionException( + raise self.defunct(ConnectionException( "Problem while setting keyspace: %r" % (exc,), self.host)) + class ResponseWaiter(object): def __init__(self, num_responses): self.pending = num_responses + self.error = None self.responses = [None] * num_responses self.event = Event() def got_response(self, response, index): - self.responses[index] = response - self.pending -= 1 - if not self.pending: + if isinstance(response, Exception): + self.error = response self.event.set() + else: + self.responses[index] = response + self.pending -= 1 + if not self.pending: + self.event.set() + + def deliver(self): + self.event.wait() + if self.error: + raise self.error + else: + return self.responses