diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 56e34e97..8f419486 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -27,7 +27,6 @@ import socket import sys import time from threading import Lock, RLock, Thread, Event -import warnings import six from six.moves import range @@ -72,7 +71,6 @@ from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, BatchStatement, bind_params, QueryTrace, Statement, named_tuple_factory, dict_factory, FETCH_SIZE_UNSET) - def _is_eventlet_monkey_patched(): if 'eventlet.patcher' not in sys.modules: return False @@ -1267,7 +1265,8 @@ class Session(object): """ A default timeout, measured in seconds, for queries executed through :meth:`.execute()` or :meth:`.execute_async()`. This default may be - overridden with the `timeout` parameter for either of those methods. + overridden with the `timeout` parameter for either of those methods + or the `timeout` parameter for :meth:`.ResponseFuture.result()`. Setting this to :const:`None` will cause no timeouts to be set by default. @@ -1402,14 +1401,17 @@ class Session(object): trace details, the :attr:`~.Statement.trace` attribute will be left as :const:`None`. """ + if timeout is _NOT_SET: + timeout = self.default_timeout + if trace and not isinstance(query, Statement): raise TypeError( "The query argument must be an instance of a subclass of " "cassandra.query.Statement when trace=True") - future = self.execute_async(query, parameters, trace, timeout) + future = self.execute_async(query, parameters, trace) try: - result = future.result() + result = future.result(timeout) finally: if trace: try: @@ -1419,7 +1421,7 @@ class Session(object): return result - def execute_async(self, query, parameters=None, trace=False, timeout=_NOT_SET): + def execute_async(self, query, parameters=None, trace=False): """ Execute the given query and return a :class:`~.ResponseFuture` object which callbacks may be attached to for asynchronous response @@ -1456,14 +1458,11 @@ class Session(object): ... log.exception("Operation failed:") """ - if timeout is _NOT_SET: - timeout = self.default_timeout - - future = self._create_response_future(query, parameters, trace, timeout) + future = self._create_response_future(query, parameters, trace) future.send_request() return future - def _create_response_future(self, query, parameters, trace, timeout): + def _create_response_future(self, query, parameters, trace): """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None @@ -1514,7 +1513,7 @@ class Session(object): message.tracing = True return ResponseFuture( - self, message, query, timeout, metrics=self._metrics, + self, message, query, self.default_timeout, metrics=self._metrics, prepared_statement=prepared_statement) def prepare(self, query): @@ -1544,10 +1543,10 @@ class Session(object): Preparing the same query more than once will likely affect performance. """ message = PrepareMessage(query=query) - future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) + future = ResponseFuture(self, message, query=None) try: future.send_request() - query_id, column_metadata = future.result() + query_id, column_metadata = future.result(self.default_timeout) except Exception: log.exception("Error preparing query:") raise @@ -1572,7 +1571,7 @@ class Session(object): futures = [] for host in self._pools.keys(): if host != excluded_host and host.is_up: - future = ResponseFuture(self, PrepareMessage(query=query), None, self.default_timeout) + future = ResponseFuture(self, PrepareMessage(query=query), None) # we don't care about errors preparing against specific hosts, # since we can always prepare them as needed when the prepared @@ -1593,7 +1592,7 @@ class Session(object): for host, future in futures: try: - future.result() + future.result(self.default_timeout) except Exception: log.exception("Error preparing query for host %s:", host) @@ -2580,14 +2579,13 @@ class ResponseFuture(object): _start_time = None _metrics = None _paging_state = None - _timer = None - def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None): + def __init__(self, session, message, query, default_timeout=None, metrics=None, prepared_statement=None): self.session = session self.row_factory = session.row_factory self.message = message self.query = query - self.timeout = timeout + self.default_timeout = default_timeout self._metrics = metrics self.prepared_statement = prepared_statement self._callback_lock = Lock() @@ -2598,18 +2596,6 @@ class ResponseFuture(object): self._errors = {} self._callbacks = [] self._errbacks = [] - self._start_timer() - - def _start_timer(self): - if self.timeout is not None: - self._timer = self.session.cluster.connection_class.create_timer(self.timeout, self._on_timeout) - - def _cancel_timer(self): - if self._timer: - self._timer.cancel() - - def _on_timeout(self): - self._set_final_exception(OperationTimedOut(self._errors, self._current_host)) def _make_query_plan(self): # convert the list/generator/etc to an iterator so that subsequent @@ -2698,7 +2684,6 @@ class ResponseFuture(object): self._event.clear() self._final_result = _NOT_SET self._final_exception = None - self._start_timer() self.send_request() def _reprepare(self, prepare_message): @@ -2905,7 +2890,6 @@ class ResponseFuture(object): "statement on host %s: %s" % (self._current_host, response))) def _set_final_result(self, response): - self._cancel_timer() if self._metrics is not None: self._metrics.request_timer.addValue(time.time() - self._start_time) @@ -2920,7 +2904,6 @@ class ResponseFuture(object): fn(response, *args, **kwargs) def _set_final_exception(self, response): - self._cancel_timer() if self._metrics is not None: self._metrics.request_timer.addValue(time.time() - self._start_time) @@ -2964,11 +2947,6 @@ class ResponseFuture(object): encountered. If the final result or error has not been set yet, this method will block until that time. - .. versionchanged:: 2.6.0 - - **`timeout` is deprecated. Use timeout in the Session execute functions instead. - The following description applies to deprecated behavior:** - You may set a timeout (in seconds) with the `timeout` parameter. By default, the :attr:`~.default_timeout` for the :class:`.Session` this was created through will be used for the timeout on this @@ -2982,6 +2960,11 @@ class ResponseFuture(object): This is a client-side timeout. For more information about server-side coordinator timeouts, see :class:`.policies.RetryPolicy`. + **Important**: This timeout currently has no effect on callbacks registered + on a :class:`~.ResponseFuture` through :meth:`.ResponseFuture.add_callback` or + :meth:`.ResponseFuture.add_errback`; even if a query exceeds this default + timeout, neither the registered callback or errback will be called. + Example usage:: >>> future = session.execute_async("SELECT * FROM mycf") @@ -2995,24 +2978,27 @@ class ResponseFuture(object): ... log.exception("Operation failed:") """ - if timeout is not _NOT_SET: - msg = "ResponseFuture.result timeout argument is deprecated. Specify the request timeout via Session.execute[_async]." - warnings.warn(msg, DeprecationWarning) - log.warning(msg) - else: - timeout = None + if timeout is _NOT_SET: + timeout = self.default_timeout - self._event.wait(timeout) - # TODO: remove this conditional when deprecated timeout parameter is removed - if not self._event.is_set(): - self._on_timeout() if self._final_result is not _NOT_SET: if self._paging_state is None: return self._final_result else: - return PagedResult(self, self._final_result) - else: + return PagedResult(self, self._final_result, timeout) + elif self._final_exception: raise self._final_exception + else: + self._event.wait(timeout=timeout) + if self._final_result is not _NOT_SET: + if self._paging_state is None: + return self._final_result + else: + return PagedResult(self, self._final_result, timeout) + elif self._final_exception: + raise self._final_exception + else: + raise OperationTimedOut(errors=self._errors, last_host=self._current_host) def get_query_trace(self, max_wait=None): """ @@ -3162,9 +3148,10 @@ class PagedResult(object): response_future = None - def __init__(self, response_future, initial_response): + def __init__(self, response_future, initial_response, timeout=_NOT_SET): self.response_future = response_future self.current_response = iter(initial_response) + self.timeout = timeout def __iter__(self): return self @@ -3177,7 +3164,7 @@ class PagedResult(object): raise self.response_future.start_fetching_next_page() - result = self.response_future.result() + result = self.response_future.result(self.timeout) if self.response_future.has_more_pages: self.current_response = result.current_response else: diff --git a/cassandra/connection.py b/cassandra/connection.py index 0d0c024b..c239313c 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -16,12 +16,11 @@ from __future__ import absolute_import # to enable import io from stdlib from collections import defaultdict, deque import errno from functools import wraps, partial -from heapq import heappush, heappop import io import logging import os import sys -from threading import Thread, Event, RLock, Lock +from threading import Thread, Event, RLock import time if 'gevent.monkey' in sys.modules: @@ -39,8 +38,7 @@ from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessag QueryMessage, ResultMessage, decode_response, InvalidRequestException, SupportedMessage, AuthResponseMessage, AuthChallengeMessage, - AuthSuccessMessage, ProtocolException, - RegisterMessage) + AuthSuccessMessage, ProtocolException) from cassandra.util import OrderedDict @@ -194,7 +192,6 @@ class Connection(object): self.is_control_connection = is_control_connection self.user_type_map = user_type_map self._push_watchers = defaultdict(set) - self._callbacks = {} self._iobuf = io.BytesIO() if protocol_version >= 3: self._header_unpack = v3_header_unpack @@ -225,7 +222,6 @@ class Connection(object): self._full_header_length = self._header_length + 4 self.lock = RLock() - self.connected_event = Event() @classmethod def initialize_reactor(self): @@ -260,10 +256,6 @@ class Connection(object): else: return conn - @classmethod - def create_timer(cls, timeout, callback): - raise NotImplementedError() - def close(self): raise NotImplementedError() @@ -375,24 +367,11 @@ class Connection(object): self.defunct(exc) raise - def register_watcher(self, event_type, callback, register_timeout=None): - """ - Register a callback for a given event type. - """ - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=[event_type]), - timeout=register_timeout) + def register_watcher(self, event_type, callback): + raise NotImplementedError() - def register_watchers(self, type_callback_dict, register_timeout=None): - """ - Register multiple callback/event type pairs, expressed as a dict. - """ - for event_type, callback in type_callback_dict.items(): - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=type_callback_dict.keys()), - timeout=register_timeout) + def register_watchers(self, type_callback_dict): + raise NotImplementedError() def control_conn_disposed(self): self.is_control_connection = False @@ -892,76 +871,3 @@ class ConnectionHeartbeat(Thread): def _raise_if_stopped(self): if self._shutdown_event.is_set(): raise self.ShutdownException() - - -class Timer(object): - - canceled = False - - def __init__(self, timeout, callback): - self.end = time.time() + timeout - self.callback = callback - if timeout < 0: - self.callback() - - def cancel(self): - self.canceled = True - - def finish(self, time_now): - if self.canceled: - return True - - if time_now >= self.end: - self.callback() - return True - - return False - - -class TimerManager(object): - - def __init__(self): - self._queue = [] - self._new_timers = [] - - def add_timer(self, timer): - """ - called from client thread with a Timer object - """ - self._new_timers.append((timer.end, timer)) - - def service_timeouts(self): - """ - run callbacks on all expired timers - Called from the event thread - :return: next end time, or None - """ - queue = self._queue - new_timers = self._new_timers - while self._new_timers: - heappush(queue, new_timers.pop()) - now = time.time() - while queue: - try: - timer = queue[0][1] - if timer.finish(now): - heappop(queue) - else: - return timer.end - except Exception: - log.exception("Exception while servicing timeout callback: ") - - @property - def next_timeout(self): - try: - return self._queue[0][0] - except IndexError: - pass - - @property - def next_offset(self): - try: - next_end = self._queue[0][0] - return next_end - time.time() - except IndexError: - pass diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index b815d20a..ef687c38 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -19,12 +19,11 @@ import os import socket import sys from threading import Event, Lock, Thread -import time import weakref from six.moves import range -from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN +from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode try: from weakref import WeakSet except ImportError: @@ -37,9 +36,10 @@ try: except ImportError: ssl = None # NOQA +from cassandra import OperationTimedOut from cassandra.connection import (Connection, ConnectionShutdown, - ConnectionException, NONBLOCKING, - Timer, TimerManager) + ConnectionException, NONBLOCKING) +from cassandra.protocol import RegisterMessage log = logging.getLogger(__name__) @@ -55,17 +55,15 @@ def _cleanup(loop_weakref): class AsyncoreLoop(object): - def __init__(self): self._pid = os.getpid() self._loop_lock = Lock() self._started = False self._shutdown = False + self._conns_lock = Lock() + self._conns = WeakSet() self._thread = None - - self._timers = TimerManager() - atexit.register(partial(_cleanup, weakref.ref(self))) def maybe_start(self): @@ -88,22 +86,24 @@ class AsyncoreLoop(object): def _run_loop(self): log.debug("Starting asyncore event loop") with self._loop_lock: - while not self._shutdown: + while True: try: - asyncore.loop(timeout=0.001, use_poll=True, count=100) - self._timers.service_timeouts() - if not asyncore.socket_map: - time.sleep(0.005) + asyncore.loop(timeout=0.001, use_poll=True, count=1000) except Exception: log.debug("Asyncore event loop stopped unexepectedly", exc_info=True) break + + if self._shutdown: + break + + with self._conns_lock: + if len(self._conns) == 0: + break + self._started = False log.debug("Asyncore event loop ended") - def add_timer(self, timer): - self._timers.add_timer(timer) - def _cleanup(self): self._shutdown = True if not self._thread: @@ -118,6 +118,14 @@ class AsyncoreLoop(object): log.debug("Event loop thread was joined") + def connection_created(self, connection): + with self._conns_lock: + self._conns.add(connection) + + def connection_destroyed(self, connection): + with self._conns_lock: + self._conns.discard(connection) + class AsyncoreConnection(Connection, asyncore.dispatcher): """ @@ -148,19 +156,18 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): cls._loop._cleanup() cls._loop = None - @classmethod - def create_timer(cls, timeout, callback): - timer = Timer(timeout, callback) - cls._loop.add_timer(timer) - return timer - def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) asyncore.dispatcher.__init__(self) + self.connected_event = Event() + + self._callbacks = {} self.deque = deque() self.deque_lock = Lock() + self._loop.connection_created(self) + sockerr = None addresses = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) for (af, socktype, proto, canonname, sockaddr) in addresses: @@ -233,6 +240,8 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): asyncore.dispatcher.close(self) log.debug("Closed socket to %s", self.host) + self._loop.connection_destroyed(self) + if not self.is_defunct: self.error_all_callbacks( ConnectionShutdown("Connection to %s was closed" % self.host)) @@ -314,3 +323,14 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): def readable(self): return self._readable or (self.is_control_connection and not (self.is_defunct or self.is_closed)) + + def register_watcher(self, event_type, callback, register_timeout=None): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), timeout=register_timeout) diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py index aceac55e..670d0f18 100644 --- a/cassandra/io/eventletreactor.py +++ b/cassandra/io/eventletreactor.py @@ -16,6 +16,7 @@ # Originally derived from MagnetoDB source: # https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py +from collections import defaultdict from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL import eventlet from eventlet.green import select, socket @@ -24,11 +25,12 @@ from functools import partial import logging import os from threading import Event -import time from six.moves import xrange -from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager +from cassandra import OperationTimedOut +from cassandra.connection import Connection, ConnectionShutdown +from cassandra.protocol import RegisterMessage log = logging.getLogger(__name__) @@ -51,45 +53,19 @@ class EventletConnection(Connection): _write_watcher = None _socket = None - _timers = None - _timeout_watcher = None - _new_timer = None - @classmethod def initialize_reactor(cls): eventlet.monkey_patch() - if not cls._timers: - cls._timers = TimerManager() - cls._timeout_watcher = eventlet.spawn(cls.service_timeouts) - cls._new_timer = Event() - - @classmethod - def create_timer(cls, timeout, callback): - timer = Timer(timeout, callback) - cls._timers.add_timer(timer) - cls._new_timer.set() - return timer - - @classmethod - def service_timeouts(cls): - """ - cls._timeout_watcher runs in this loop forever. - It is usually waiting for the next timeout on the cls._new_timer Event. - When new timers are added, that event is set so that the watcher can - wake up and possibly set an earlier timeout. - """ - timer_manager = cls._timers - while True: - next_end = timer_manager.service_timeouts() - sleep_time = max(next_end - time.time(), 0) if next_end else 10000 - cls._new_timer.wait(sleep_time) - cls._new_timer.clear() def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) + self.connected_event = Event() self._write_queue = Queue() + self._callbacks = {} + self._push_watchers = defaultdict(set) + sockerr = None addresses = socket.getaddrinfo( self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM @@ -189,3 +165,16 @@ class EventletConnection(Connection): chunk_size = self.out_buffer_size for i in xrange(0, len(data), chunk_size): self._write_queue.put(data[i:i + chunk_size]) + + def register_watcher(self, event_type, callback, register_timeout=None): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), + timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), + timeout=register_timeout) diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py index 84285957..6e9af0da 100644 --- a/cassandra/io/geventreactor.py +++ b/cassandra/io/geventreactor.py @@ -13,20 +13,21 @@ # limitations under the License. import gevent from gevent import select, socket, ssl -import gevent.event +from gevent.event import Event from gevent.queue import Queue from collections import defaultdict from functools import partial import logging import os -import time -from six.moves import range +from six.moves import xrange from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL -from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager +from cassandra import OperationTimedOut +from cassandra.connection import Connection, ConnectionShutdown +from cassandra.protocol import RegisterMessage log = logging.getLogger(__name__) @@ -49,39 +50,15 @@ class GeventConnection(Connection): _write_watcher = None _socket = None - _timers = None - _timeout_watcher = None - _new_timer = None - - @classmethod - def initialize_reactor(cls): - if not cls._timers: - cls._timers = TimerManager() - cls._timeout_watcher = gevent.spawn(cls.service_timeouts) - cls._new_timer = gevent.event.Event() - - @classmethod - def create_timer(cls, timeout, callback): - timer = Timer(timeout, callback) - cls._timers.add_timer(timer) - cls._new_timer.set() - return timer - - @classmethod - def service_timeouts(cls): - timer_manager = cls._timers - timer_event = cls._new_timer - while True: - next_end = timer_manager.service_timeouts() - sleep_time = max(next_end - time.time(), 0) if next_end else 10000 - timer_event.wait(sleep_time) - timer_event.clear() - def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) + self.connected_event = Event() self._write_queue = Queue() + self._callbacks = {} + self._push_watchers = defaultdict(set) + sockerr = None addresses = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) for (af, socktype, proto, canonname, sockaddr) in addresses: @@ -182,5 +159,18 @@ class GeventConnection(Connection): def push(self, data): chunk_size = self.out_buffer_size - for i in range(0, len(data), chunk_size): + for i in xrange(0, len(data), chunk_size): self._write_queue.put(data[i:i + chunk_size]) + + def register_watcher(self, event_type, callback, register_timeout=None): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), + timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), + timeout=register_timeout) diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index e6abb76b..93b4c978 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -20,10 +20,11 @@ import socket from threading import Event, Lock, Thread import weakref -from six.moves import range +from six.moves import xrange -from cassandra.connection import (Connection, ConnectionShutdown, - NONBLOCKING, Timer, TimerManager) +from cassandra import OperationTimedOut +from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING +from cassandra.protocol import RegisterMessage try: import cassandra.io.libevwrapper as libev except ImportError: @@ -39,7 +40,7 @@ except ImportError: try: import ssl except ImportError: - ssl = None # NOQA + ssl = None # NOQA log = logging.getLogger(__name__) @@ -49,6 +50,7 @@ def _cleanup(loop_weakref): loop = loop_weakref() except ReferenceError: return + loop._cleanup() @@ -83,11 +85,11 @@ class LibevLoop(object): self._loop.unref() self._preparer.start() - self._timers = TimerManager() - self._loop_timer = libev.Timer(self._loop, self._on_loop_timer) - atexit.register(partial(_cleanup, weakref.ref(self))) + def notify(self): + self._notifier.send() + def maybe_start(self): should_start = False with self._lock: @@ -131,7 +133,6 @@ class LibevLoop(object): conn._read_watcher.stop() del conn._read_watcher - self.notify() # wake the timer watcher log.debug("Waiting for event loop thread to join...") self._thread.join(timeout=1.0) if self._thread.is_alive(): @@ -142,24 +143,6 @@ class LibevLoop(object): log.debug("Event loop thread was joined") self._loop = None - def add_timer(self, timer): - self._timers.add_timer(timer) - self._notifier.send() # wake up in case this timer is earlier - - def _update_timer(self): - if not self._shutdown: - self._timers.service_timeouts() - offset = self._timers.next_offset or 100000 # none pending; will be updated again when something new happens - self._loop_timer.start(offset) - else: - self._loop_timer.stop() - - def _on_loop_timer(self): - self._timers.service_timeouts() - - def notify(self): - self._notifier.send() - def connection_created(self, conn): with self._conn_set_lock: new_live_conns = self._live_conns.copy() @@ -222,9 +205,6 @@ class LibevLoop(object): changed = True - # TODO: update to do connection management, timer updates through dedicaterd async 'notifier' callbacks - self._update_timer() - if changed: self._notifier.send() @@ -256,15 +236,12 @@ class LibevConnection(Connection): cls._libevloop._cleanup() cls._libevloop = None - @classmethod - def create_timer(cls, timeout, callback): - timer = Timer(timeout, callback) - cls._libevloop.add_timer(timer) - return timer - def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) + self.connected_event = Event() + + self._callbacks = {} self.deque = deque() self._deque_lock = Lock() @@ -384,7 +361,7 @@ class LibevConnection(Connection): sabs = self.out_buffer_size if len(data) > sabs: chunks = [] - for i in range(0, len(data), sabs): + for i in xrange(0, len(data), sabs): chunks.append(data[i:i + sabs]) else: chunks = [data] @@ -392,3 +369,14 @@ class LibevConnection(Connection): with self._deque_lock: self.deque.extend(chunks) self._libevloop.notify() + + def register_watcher(self, event_type, callback, register_timeout=None): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), timeout=register_timeout) diff --git a/cassandra/io/libevwrapper.c b/cassandra/io/libevwrapper.c index 99e1df30..cbac83b2 100644 --- a/cassandra/io/libevwrapper.c +++ b/cassandra/io/libevwrapper.c @@ -451,131 +451,6 @@ static PyTypeObject libevwrapper_PrepareType = { (initproc)Prepare_init, /* tp_init */ }; -typedef struct libevwrapper_Timer { - PyObject_HEAD - struct ev_timer timer; - struct libevwrapper_Loop *loop; - PyObject *callback; -} libevwrapper_Timer; - -static void -Timer_dealloc(libevwrapper_Timer *self) { - Py_XDECREF(self->loop); - Py_XDECREF(self->callback); - Py_TYPE(self)->tp_free((PyObject *)self); -} - -static void timer_callback(struct ev_loop *loop, ev_timer *watcher, int revents) { - libevwrapper_Timer *self = watcher->data; - - PyObject *result = NULL; - PyGILState_STATE gstate; - - gstate = PyGILState_Ensure(); - result = PyObject_CallFunction(self->callback, NULL); - if (!result) { - PyErr_WriteUnraisable(self->callback); - } - Py_XDECREF(result); - - PyGILState_Release(gstate); -} - -static int -Timer_init(libevwrapper_Timer *self, PyObject *args, PyObject *kwds) { - PyObject *callback; - PyObject *loop; - - if (!PyArg_ParseTuple(args, "OO", &loop, &callback)) { - return -1; - } - - if (loop) { - Py_INCREF(loop); - self->loop = (libevwrapper_Loop *)loop; - } else { - return -1; - } - - if (callback) { - if (!PyCallable_Check(callback)) { - PyErr_SetString(PyExc_TypeError, "callback parameter must be callable"); - Py_XDECREF(loop); - return -1; - } - Py_INCREF(callback); - self->callback = callback; - } - ev_init(&self->timer, timer_callback); - self->timer.data = self; - return 0; -} - -static PyObject * -Timer_start(libevwrapper_Timer *self, PyObject *args) { - double timeout; - if (!PyArg_ParseTuple(args, "d", &timeout)) { - return NULL; - } - /* some tiny non-zero number to avoid zero, and - make it run immediately for negative timeouts */ - self->timer.repeat = fmax(timeout, 0.000000001); - ev_timer_again(self->loop->loop, &self->timer); - Py_RETURN_NONE; -} - -static PyObject * -Timer_stop(libevwrapper_Timer *self, PyObject *args) { - ev_timer_stop(self->loop->loop, &self->timer); - Py_RETURN_NONE; -} - -static PyMethodDef Timer_methods[] = { - {"start", (PyCFunction)Timer_start, METH_VARARGS, "Start the Timer watcher"}, - {"stop", (PyCFunction)Timer_stop, METH_NOARGS, "Stop the Timer watcher"}, - {NULL} /* Sentinal */ -}; - -static PyTypeObject libevwrapper_TimerType = { - PyVarObject_HEAD_INIT(NULL, 0) - "cassandra.io.libevwrapper.Timer", /*tp_name*/ - sizeof(libevwrapper_Timer), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - (destructor)Timer_dealloc, /*tp_dealloc*/ - 0, /*tp_print*/ - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - 0, /*tp_compare*/ - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - 0, /*tp_as_sequence*/ - 0, /*tp_as_mapping*/ - 0, /*tp_hash */ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ - "Timer objects", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - Timer_methods, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)Timer_init, /* tp_init */ -}; - - static PyMethodDef module_methods[] = { {NULL} /* Sentinal */ }; @@ -625,10 +500,6 @@ initlibevwrapper(void) if (PyType_Ready(&libevwrapper_AsyncType) < 0) INITERROR; - libevwrapper_TimerType.tp_new = PyType_GenericNew; - if (PyType_Ready(&libevwrapper_TimerType) < 0) - INITERROR; - # if PY_MAJOR_VERSION >= 3 module = PyModule_Create(&moduledef); # else @@ -661,10 +532,6 @@ initlibevwrapper(void) if (PyModule_AddObject(module, "Async", (PyObject *)&libevwrapper_AsyncType) == -1) INITERROR; - Py_INCREF(&libevwrapper_TimerType); - if (PyModule_AddObject(module, "Timer", (PyObject *)&libevwrapper_TimerType) == -1) - INITERROR; - if (!PyEval_ThreadsInitialized()) { PyEval_InitThreads(); } diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py index 0f5c841c..ff81e561 100644 --- a/cassandra/io/twistedreactor.py +++ b/cassandra/io/twistedreactor.py @@ -15,15 +15,16 @@ Module that implements an event loop based on twisted ( https://twistedmatrix.com ). """ -import atexit +from twisted.internet import reactor, protocol +from threading import Event, Thread, Lock from functools import partial import logging -from threading import Event, Thread, Lock -import time -from twisted.internet import reactor, protocol import weakref +import atexit -from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager +from cassandra import OperationTimedOut +from cassandra.connection import Connection, ConnectionShutdown +from cassandra.protocol import RegisterMessage log = logging.getLogger(__name__) @@ -108,12 +109,9 @@ class TwistedLoop(object): _lock = None _thread = None - _timeout_task = None - _timeout = None def __init__(self): self._lock = Lock() - self._timers = TimerManager() def maybe_start(self): with self._lock: @@ -135,27 +133,6 @@ class TwistedLoop(object): "Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") - def add_timer(self, timer): - self._timers.add_timer(timer) - # callFromThread to schedule from the loop thread, where - # the timeout task can safely be modified - reactor.callFromThread(self._schedule_timeout, timer.end) - - def _schedule_timeout(self, next_timeout): - if next_timeout: - delay = max(next_timeout - time.time(), 0) - if self._timeout_task and self._timeout_task.active(): - if next_timeout < self._timeout: - self._timeout_task.reset(delay) - self._timeout = next_timeout - else: - self._timeout_task = reactor.callLater(delay, self._on_loop_timer) - self._timeout = next_timeout - - def _on_loop_timer(self): - self._timers.service_timeouts() - self._schedule_timeout(self._timers.next_timeout) - class TwistedConnection(Connection): """ @@ -171,12 +148,6 @@ class TwistedConnection(Connection): if not cls._loop: cls._loop = TwistedLoop() - @classmethod - def create_timer(cls, timeout, callback): - timer = Timer(timeout, callback) - cls._loop.add_timer(timer) - return timer - def __init__(self, *args, **kwargs): """ Initialization method. @@ -188,9 +159,11 @@ class TwistedConnection(Connection): """ Connection.__init__(self, *args, **kwargs) + self.connected_event = Event() self.is_closed = True self.connector = None + self._callbacks = {} reactor.callFromThread(self.add_connection) self._loop.maybe_start() @@ -247,3 +220,22 @@ class TwistedConnection(Connection): the event loop when it gets the chance. """ reactor.callFromThread(self.connector.transport.write, data) + + def register_watcher(self, event_type, callback, register_timeout=None): + """ + Register a callback for a given event type. + """ + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), + timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + """ + Register multiple callback/event type pairs, expressed as a dict. + """ + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), + timeout=register_timeout) diff --git a/cassandra/query.py b/cassandra/query.py index 30bf4501..1232514b 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -844,14 +844,14 @@ class QueryTrace(object): break def _execute(self, query, parameters, time_spent, max_wait): - timeout = (max_wait - time_spent) if max_wait is not None else None - future = self._session._create_response_future(query, parameters, trace=False, timeout=timeout) # in case the user switched the row factory, set it to namedtuple for this query + future = self._session._create_response_future(query, parameters, trace=False) future.row_factory = named_tuple_factory future.send_request() + timeout = (max_wait - time_spent) if max_wait is not None else None try: - return future.result() + return future.result(timeout=timeout) except OperationTimedOut: raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) diff --git a/docs/object_mapper.rst b/docs/object_mapper.rst index 4e389940..26d78a09 100644 --- a/docs/object_mapper.rst +++ b/docs/object_mapper.rst @@ -48,7 +48,7 @@ Getting Started from cassandra.cqlengine import columns from cassandra.cqlengine import connection from datetime import datetime - from cassandra.cqlengine.management import sync_table + from cassandra.cqlengine.management import create_keyspace, sync_table from cassandra.cqlengine.models import Model #first, define a model diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index a779e133..7fc4ed4e 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -261,7 +261,10 @@ class ConnectionTest(unittest.TestCase): Ensure the following methods throw NIE's. If not, come back and test them. """ c = self.make_connection() + self.assertRaises(NotImplementedError, c.close) + self.assertRaises(NotImplementedError, c.register_watcher, None, None) + self.assertRaises(NotImplementedError, c.register_watchers, None) def test_set_keyspace_blocking(self): c = self.make_connection() diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 986945ba..027fe732 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -47,7 +47,7 @@ class ResponseFutureTests(unittest.TestCase): def make_response_future(self, session): query = SimpleStatement("SELECT * FROM foo") message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - return ResponseFuture(session, message, query, 1) + return ResponseFuture(session, message, query) def make_mock_response(self, results): return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=results, paging_state=None) @@ -122,7 +122,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_read_timeout.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query, 1) + rf = ResponseFuture(session, message, query) rf.send_request() result = Mock(spec=ReadTimeoutErrorMessage, info={}) @@ -137,7 +137,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_write_timeout.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query, 1) + rf = ResponseFuture(session, message, query) rf.send_request() result = Mock(spec=WriteTimeoutErrorMessage, info={}) @@ -151,7 +151,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query, 1) + rf = ResponseFuture(session, message, query) rf.send_request() result = Mock(spec=UnavailableErrorMessage, info={}) @@ -165,7 +165,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_unavailable.return_value = (RetryPolicy.IGNORE, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query, 1) + rf = ResponseFuture(session, message, query) rf.send_request() result = Mock(spec=UnavailableErrorMessage, info={}) @@ -184,7 +184,7 @@ class ResponseFutureTests(unittest.TestCase): connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) - rf = ResponseFuture(session, message, query, 1) + rf = ResponseFuture(session, message, query) rf.send_request() rf.session._pools.get.assert_called_once_with('ip1') @@ -279,7 +279,7 @@ class ResponseFutureTests(unittest.TestCase): session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2'] session._pools.get.return_value.is_shutdown = True - rf = ResponseFuture(session, Mock(), Mock(), 1) + rf = ResponseFuture(session, Mock(), Mock()) rf.send_request() self.assertRaises(NoHostAvailable, rf.result) @@ -354,7 +354,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query, 1) + rf = ResponseFuture(session, message, query) rf.send_request() rf.add_errback(self.assertIsInstance, Exception) @@ -401,7 +401,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query, 1) + rf = ResponseFuture(session, message, query) rf.send_request() callback = Mock() @@ -431,7 +431,7 @@ class ResponseFutureTests(unittest.TestCase): message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) # test errback - rf = ResponseFuture(session, message, query, 1) + rf = ResponseFuture(session, message, query) rf.send_request() rf.add_callbacks( @@ -443,7 +443,7 @@ class ResponseFutureTests(unittest.TestCase): self.assertRaises(Exception, rf.result) # test callback - rf = ResponseFuture(session, message, query, 1) + rf = ResponseFuture(session, message, query) rf.send_request() callback = Mock()