diff --git a/cassandra/connection.py b/cassandra/connection.py index 027259ea..22d4fd01 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -357,13 +357,16 @@ class Connection(object): class ResponseWaiter(object): - def __init__(self, num_responses): + def __init__(self, connection, num_responses): + self.connection = connection self.pending = num_responses self.error = None self.responses = [None] * num_responses self.event = Event() def got_response(self, response, index): + with self.connection.lock: + self.connection.in_flight -= 1 if isinstance(response, Exception): self.error = response self.event.set() diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index 706e7b7c..74373576 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -347,7 +347,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): def wait_for_responses(self, *msgs, **kwargs): timeout = kwargs.get('timeout') - waiter = ResponseWaiter(len(msgs)) + waiter = ResponseWaiter(self, len(msgs)) # busy wait for sufficient space on the connection messages_sent = 0 @@ -370,11 +370,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): raise OperationTimedOut() time.sleep(0.01) - try: - return waiter.deliver(timeout) - finally: - with self.lock: - self.in_flight -= len(msgs) + return waiter.deliver(timeout) def register_watcher(self, event_type, callback): self._push_watchers[event_type].add(callback) diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index e5dbcd7b..10475d1f 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -395,7 +395,7 @@ class LibevConnection(Connection): def wait_for_responses(self, *msgs, **kwargs): timeout = kwargs.get('timeout') - waiter = ResponseWaiter(len(msgs)) + waiter = ResponseWaiter(self, len(msgs)) # busy wait for sufficient space on the connection messages_sent = 0 @@ -418,11 +418,7 @@ class LibevConnection(Connection): raise OperationTimedOut() time.sleep(0.01) - try: - return waiter.deliver(timeout) - finally: - with self.lock: - self.in_flight -= len(msgs) + return waiter.deliver(timeout) def register_watcher(self, event_type, callback): self._push_watchers[event_type].add(callback)