diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 828fa826..45784c46 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1442,8 +1442,7 @@ class Session(object): """ A default timeout, measured in seconds, for queries executed through :meth:`.execute()` or :meth:`.execute_async()`. This default may be - overridden with the `timeout` parameter for either of those methods - or the `timeout` parameter for :meth:`.ResponseFuture.result()`. + overridden with the `timeout` parameter for either of those methods. Setting this to :const:`None` will cause no timeouts to be set by default. @@ -1581,17 +1580,14 @@ class Session(object): If `query` is a Statement with its own custom_payload. The message payload will be a union of the two, with the values specified here taking precedence. """ - if timeout is _NOT_SET: - timeout = self.default_timeout - if trace and not isinstance(query, Statement): raise TypeError( "The query argument must be an instance of a subclass of " "cassandra.query.Statement when trace=True") - future = self.execute_async(query, parameters, trace, custom_payload) + future = self.execute_async(query, parameters, trace, custom_payload, timeout) try: - result = future.result(timeout) + result = future.result() finally: if trace: try: @@ -1601,7 +1597,7 @@ class Session(object): return result - def execute_async(self, query, parameters=None, trace=False, custom_payload=None): + def execute_async(self, query, parameters=None, trace=False, custom_payload=None, timeout=_NOT_SET): """ Execute the given query and return a :class:`~.ResponseFuture` object which callbacks may be attached to for asynchronous response @@ -1646,11 +1642,14 @@ class Session(object): ... log.exception("Operation failed:") """ - future = self._create_response_future(query, parameters, trace, custom_payload) + if timeout is _NOT_SET: + timeout = self.default_timeout + + future = self._create_response_future(query, parameters, trace, custom_payload, timeout) future.send_request() return future - def _create_response_future(self, query, parameters, trace, custom_payload): + def _create_response_future(self, query, parameters, trace, custom_payload, timeout): """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None @@ -1704,7 +1703,7 @@ class Session(object): message.update_custom_payload(custom_payload) return ResponseFuture( - self, message, query, self.default_timeout, metrics=self._metrics, + self, message, query, timeout, metrics=self._metrics, prepared_statement=prepared_statement) def prepare(self, query, custom_payload=None): @@ -1737,11 +1736,10 @@ class Session(object): message. See :ref:`custom_payload`. """ message = PrepareMessage(query=query) - message.custom_payload = custom_payload - future = ResponseFuture(self, message, query=None) + future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) try: future.send_request() - query_id, column_metadata, pk_indexes = future.result(self.default_timeout) + query_id, column_metadata, pk_indexes = future.result() except Exception: log.exception("Error preparing query:") raise @@ -1767,7 +1765,7 @@ class Session(object): futures = [] for host in self._pools.keys(): if host != excluded_host and host.is_up: - future = ResponseFuture(self, PrepareMessage(query=query), None) + future = ResponseFuture(self, PrepareMessage(query=query), None, self.default_timeout) # we don't care about errors preparing against specific hosts, # since we can always prepare them as needed when the prepared @@ -1788,7 +1786,7 @@ class Session(object): for host, future in futures: try: - future.result(self.default_timeout) + future.result() except Exception: log.exception("Error preparing query for host %s:", host) @@ -2832,13 +2830,14 @@ class ResponseFuture(object): _paging_state = None _custom_payload = None _warnings = None + _timer = None - def __init__(self, session, message, query, default_timeout=None, metrics=None, prepared_statement=None): + def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None): self.session = session self.row_factory = session.row_factory self.message = message self.query = query - self.default_timeout = default_timeout + self.timeout = timeout self._metrics = metrics self.prepared_statement = prepared_statement self._callback_lock = Lock() @@ -2849,6 +2848,18 @@ class ResponseFuture(object): self._errors = {} self._callbacks = [] self._errbacks = [] + self._start_timer() + + def _start_timer(self): + if self.timeout is not None: + self._timer = self.session.cluster.connection_class.create_timer(self.timeout, self._on_timeout) + + def _cancel_timer(self): + if self._timer: + self._timer.cancel() + + def _on_timeout(self): + self._set_final_exception(OperationTimedOut(self._errors, self._current_host)) def _make_query_plan(self): # convert the list/generator/etc to an iterator so that subsequent @@ -2973,6 +2984,7 @@ class ResponseFuture(object): self._event.clear() self._final_result = _NOT_SET self._final_exception = None + self._start_timer() self.send_request() def _reprepare(self, prepare_message): @@ -3187,6 +3199,7 @@ class ResponseFuture(object): "statement on host %s: %s" % (self._current_host, response))) def _set_final_result(self, response): + self._cancel_timer() if self._metrics is not None: self._metrics.request_timer.addValue(time.time() - self._start_time) @@ -3201,6 +3214,7 @@ class ResponseFuture(object): fn(response, *args, **kwargs) def _set_final_exception(self, response): + self._cancel_timer() if self._metrics is not None: self._metrics.request_timer.addValue(time.time() - self._start_time) @@ -3244,6 +3258,11 @@ class ResponseFuture(object): encountered. If the final result or error has not been set yet, this method will block until that time. + .. versionchanged:: 2.6.0 + + **`timeout` is deprecated. Use timeout in the Session execute functions instead. + The following description applies to deprecated behavior:** + You may set a timeout (in seconds) with the `timeout` parameter. By default, the :attr:`~.default_timeout` for the :class:`.Session` this was created through will be used for the timeout on this @@ -3257,11 +3276,6 @@ class ResponseFuture(object): This is a client-side timeout. For more information about server-side coordinator timeouts, see :class:`.policies.RetryPolicy`. - **Important**: This timeout currently has no effect on callbacks registered - on a :class:`~.ResponseFuture` through :meth:`.ResponseFuture.add_callback` or - :meth:`.ResponseFuture.add_errback`; even if a query exceeds this default - timeout, neither the registered callback or errback will be called. - Example usage:: >>> future = session.execute_async("SELECT * FROM mycf") @@ -3275,27 +3289,24 @@ class ResponseFuture(object): ... log.exception("Operation failed:") """ - if timeout is _NOT_SET: - timeout = self.default_timeout + if timeout is not _NOT_SET: + msg = "ResponseFuture.result timeout argument is deprecated. Specify the request timeout via Session.execute[_async]." + warnings.warn(msg, DeprecationWarning) + log.warning(msg) + else: + timeout = None + self._event.wait(timeout) + # TODO: remove this conditional when deprecated timeout parameter is removed + if not self._event.is_set(): + self._on_timeout() if self._final_result is not _NOT_SET: if self._paging_state is None: return self._final_result else: - return PagedResult(self, self._final_result, timeout) - elif self._final_exception: - raise self._final_exception + return PagedResult(self, self._final_result) else: - self._event.wait(timeout=timeout) - if self._final_result is not _NOT_SET: - if self._paging_state is None: - return self._final_result - else: - return PagedResult(self, self._final_result, timeout) - elif self._final_exception: - raise self._final_exception - else: - raise OperationTimedOut(errors=self._errors, last_host=self._current_host) + raise self._final_exception def get_query_trace(self, max_wait=None): """ @@ -3450,10 +3461,9 @@ class PagedResult(object): response_future = None - def __init__(self, response_future, initial_response, timeout=_NOT_SET): + def __init__(self, response_future, initial_response): self.response_future = response_future self.current_response = iter(initial_response) - self.timeout = timeout def __iter__(self): return self @@ -3466,7 +3476,7 @@ class PagedResult(object): raise self.response_future.start_fetching_next_page() - result = self.response_future.result(self.timeout) + result = self.response_future.result() if self.response_future.has_more_pages: self.current_response = result.current_response else: diff --git a/cassandra/connection.py b/cassandra/connection.py index 464ec792..95858bd8 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -16,6 +16,7 @@ from __future__ import absolute_import # to enable import io from stdlib from collections import defaultdict, deque, namedtuple import errno from functools import wraps, partial +from heapq import heappush, heappop import io import logging import socket @@ -44,7 +45,8 @@ from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessag QueryMessage, ResultMessage, decode_response, InvalidRequestException, SupportedMessage, AuthResponseMessage, AuthChallengeMessage, - AuthSuccessMessage, ProtocolException, MAX_SUPPORTED_VERSION) + AuthSuccessMessage, ProtocolException, + MAX_SUPPORTED_VERSION, RegisterMessage) from cassandra.util import OrderedDict @@ -219,6 +221,7 @@ class Connection(object): self.is_control_connection = is_control_connection self.user_type_map = user_type_map self._push_watchers = defaultdict(set) + self._callbacks = {} self._iobuf = io.BytesIO() if protocol_version >= 3: @@ -233,6 +236,7 @@ class Connection(object): self.highest_request_id = self.max_request_id self.lock = RLock() + self.connected_event = Event() @classmethod def initialize_reactor(self): @@ -250,6 +254,10 @@ class Connection(object): """ pass + @classmethod + def create_timer(cls, timeout, callback): + raise NotImplementedError() + @classmethod def factory(cls, host, timeout, *args, **kwargs): """ @@ -407,11 +415,24 @@ class Connection(object): self.defunct(exc) raise - def register_watcher(self, event_type, callback): - raise NotImplementedError() + def register_watcher(self, event_type, callback, register_timeout=None): + """ + Register a callback for a given event type. + """ + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), + timeout=register_timeout) - def register_watchers(self, type_callback_dict): - raise NotImplementedError() + def register_watchers(self, type_callback_dict, register_timeout=None): + """ + Register multiple callback/event type pairs, expressed as a dict. + """ + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), + timeout=register_timeout) def control_conn_disposed(self): self.is_control_connection = False @@ -907,3 +928,77 @@ class ConnectionHeartbeat(Thread): def _raise_if_stopped(self): if self._shutdown_event.is_set(): raise self.ShutdownException() + + +class Timer(object): + + canceled = False + + def __init__(self, timeout, callback): + self.end = time.time() + timeout + self.callback = callback + if timeout < 0: + self.callback() + + def cancel(self): + self.canceled = True + + def finish(self, time_now): + if self.canceled: + return True + + if time_now >= self.end: + self.callback() + return True + + return False + + +class TimerManager(object): + + def __init__(self): + self._queue = [] + self._new_timers = [] + + def add_timer(self, timer): + """ + called from client thread with a Timer object + """ + self._new_timers.append((timer.end, timer)) + + def service_timeouts(self): + """ + run callbacks on all expired timers + Called from the event thread + :return: next end time, or None + """ + queue = self._queue + new_timers = self._new_timers + while new_timers: + heappush(queue, new_timers.pop()) + + now = time.time() + while queue: + try: + timer = queue[0][1] + if timer.finish(now): + heappop(queue) + else: + return timer.end + except Exception: + log.exception("Exception while servicing timeout callback: ") + + @property + def next_timeout(self): + try: + return self._queue[0][0] + except IndexError: + pass + + @property + def next_offset(self): + try: + next_end = self._queue[0][0] + return next_end - time.time() + except IndexError: + pass diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index 616b5c00..fae87a73 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -19,6 +19,7 @@ import os import socket import sys from threading import Event, Lock, Thread +import time import weakref from six.moves import range @@ -35,8 +36,9 @@ try: except ImportError: ssl = None # NOQA -from cassandra.connection import (Connection, ConnectionShutdown, NONBLOCKING) -from cassandra.protocol import RegisterMessage +from cassandra.connection import (Connection, ConnectionShutdown, + ConnectionException, NONBLOCKING, + Timer, TimerManager) log = logging.getLogger(__name__) @@ -52,15 +54,17 @@ def _cleanup(loop_weakref): class AsyncoreLoop(object): + def __init__(self): self._pid = os.getpid() self._loop_lock = Lock() self._started = False self._shutdown = False - self._conns_lock = Lock() - self._conns = WeakSet() self._thread = None + + self._timers = TimerManager() + atexit.register(partial(_cleanup, weakref.ref(self))) def maybe_start(self): @@ -83,24 +87,22 @@ class AsyncoreLoop(object): def _run_loop(self): log.debug("Starting asyncore event loop") with self._loop_lock: - while True: + while not self._shutdown: try: - asyncore.loop(timeout=0.001, use_poll=True, count=1000) + asyncore.loop(timeout=0.001, use_poll=True, count=100) + self._timers.service_timeouts() + if not asyncore.socket_map: + time.sleep(0.005) except Exception: log.debug("Asyncore event loop stopped unexepectedly", exc_info=True) break - - if self._shutdown: - break - - with self._conns_lock: - if len(self._conns) == 0: - break - self._started = False log.debug("Asyncore event loop ended") + def add_timer(self, timer): + self._timers.add_timer(timer) + def _cleanup(self): self._shutdown = True if not self._thread: @@ -115,14 +117,6 @@ class AsyncoreLoop(object): log.debug("Event loop thread was joined") - def connection_created(self, connection): - with self._conns_lock: - self._conns.add(connection) - - def connection_destroyed(self, connection): - with self._conns_lock: - self._conns.discard(connection) - class AsyncoreConnection(Connection, asyncore.dispatcher): """ @@ -152,18 +146,19 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): cls._loop._cleanup() cls._loop = None + @classmethod + def create_timer(cls, timeout, callback): + timer = Timer(timeout, callback) + cls._loop.add_timer(timer) + return timer + def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) asyncore.dispatcher.__init__(self) - self.connected_event = Event() - - self._callbacks = {} self.deque = deque() self.deque_lock = Lock() - self._loop.connection_created(self) - self._connect_socket() asyncore.dispatcher.__init__(self, self._socket) @@ -187,8 +182,6 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): asyncore.dispatcher.close(self) log.debug("Closed socket to %s", self.host) - self._loop.connection_destroyed(self) - if not self.is_defunct: self.error_all_callbacks( ConnectionShutdown("Connection to %s was closed" % self.host)) @@ -267,14 +260,3 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): def readable(self): return self._readable or (self.is_control_connection and not (self.is_defunct or self.is_closed)) - - def register_watcher(self, event_type, callback, register_timeout=None): - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=[event_type]), timeout=register_timeout) - - def register_watchers(self, type_callback_dict, register_timeout=None): - for event_type, callback in type_callback_dict.items(): - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=type_callback_dict.keys()), timeout=register_timeout) diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py index e65d5aa5..b0206f43 100644 --- a/cassandra/io/eventletreactor.py +++ b/cassandra/io/eventletreactor.py @@ -16,7 +16,6 @@ # Originally derived from MagnetoDB source: # https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py -from collections import defaultdict from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL import eventlet from eventlet.green import select, socket @@ -24,11 +23,12 @@ from eventlet.queue import Queue from functools import partial import logging import os -from six.moves import xrange from threading import Event +import time -from cassandra.connection import Connection, ConnectionShutdown -from cassandra.protocol import RegisterMessage +from six.moves import xrange + +from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) @@ -52,19 +52,45 @@ class EventletConnection(Connection): _socket_impl = eventlet.green.socket _ssl_impl = eventlet.green.ssl + _timers = None + _timeout_watcher = None + _new_timer = None + @classmethod def initialize_reactor(cls): eventlet.monkey_patch() + if not cls._timers: + cls._timers = TimerManager() + cls._timeout_watcher = eventlet.spawn(cls.service_timeouts) + cls._new_timer = Event() + + @classmethod + def create_timer(cls, timeout, callback): + timer = Timer(timeout, callback) + cls._timers.add_timer(timer) + cls._new_timer.set() + return timer + + @classmethod + def service_timeouts(cls): + """ + cls._timeout_watcher runs in this loop forever. + It is usually waiting for the next timeout on the cls._new_timer Event. + When new timers are added, that event is set so that the watcher can + wake up and possibly set an earlier timeout. + """ + timer_manager = cls._timers + while True: + next_end = timer_manager.service_timeouts() + sleep_time = max(next_end - time.time(), 0) if next_end else 10000 + cls._new_timer.wait(sleep_time) + cls._new_timer.clear() def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) - self.connected_event = Event() self._write_queue = Queue() - self._callbacks = {} - self._push_watchers = defaultdict(set) - self._connect_socket() self._read_watcher = eventlet.spawn(lambda: self.handle_read()) @@ -142,16 +168,3 @@ class EventletConnection(Connection): chunk_size = self.out_buffer_size for i in xrange(0, len(data), chunk_size): self._write_queue.put(data[i:i + chunk_size]) - - def register_watcher(self, event_type, callback, register_timeout=None): - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=[event_type]), - timeout=register_timeout) - - def register_watchers(self, type_callback_dict, register_timeout=None): - for event_type, callback in type_callback_dict.items(): - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=type_callback_dict.keys()), - timeout=register_timeout) diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py index 5d8553ce..60a3f2d0 100644 --- a/cassandra/io/geventreactor.py +++ b/cassandra/io/geventreactor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import gevent -from gevent.event import Event +import gevent.event from gevent.queue import Queue from gevent import select, socket import gevent.ssl @@ -21,13 +21,13 @@ from collections import defaultdict from functools import partial import logging import os +import time -from six.moves import xrange +from six.moves import range from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL -from cassandra.connection import Connection, ConnectionShutdown -from cassandra.protocol import RegisterMessage +from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) @@ -51,15 +51,39 @@ class GeventConnection(Connection): _socket_impl = gevent.socket _ssl_impl = gevent.ssl + _timers = None + _timeout_watcher = None + _new_timer = None + + @classmethod + def initialize_reactor(cls): + if not cls._timers: + cls._timers = TimerManager() + cls._timeout_watcher = gevent.spawn(cls.service_timeouts) + cls._new_timer = gevent.event.Event() + + @classmethod + def create_timer(cls, timeout, callback): + timer = Timer(timeout, callback) + cls._timers.add_timer(timer) + cls._new_timer.set() + return timer + + @classmethod + def service_timeouts(cls): + timer_manager = cls._timers + timer_event = cls._new_timer + while True: + next_end = timer_manager.service_timeouts() + sleep_time = max(next_end - time.time(), 0) if next_end else 10000 + timer_event.wait(sleep_time) + timer_event.clear() + def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) - self.connected_event = Event() self._write_queue = Queue() - self._callbacks = {} - self._push_watchers = defaultdict(set) - self._connect_socket() self._read_watcher = gevent.spawn(self.handle_read) @@ -142,18 +166,5 @@ class GeventConnection(Connection): def push(self, data): chunk_size = self.out_buffer_size - for i in xrange(0, len(data), chunk_size): + for i in range(0, len(data), chunk_size): self._write_queue.put(data[i:i + chunk_size]) - - def register_watcher(self, event_type, callback, register_timeout=None): - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=[event_type]), - timeout=register_timeout) - - def register_watchers(self, type_callback_dict, register_timeout=None): - for event_type, callback in type_callback_dict.items(): - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=type_callback_dict.keys()), - timeout=register_timeout) diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index 00127a1f..af0c3a60 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -17,13 +17,14 @@ from functools import partial import logging import os import socket -from threading import Event, Lock, Thread +import ssl +from threading import Lock, Thread import weakref -from six.moves import xrange +from six.moves import range -from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING, ssl -from cassandra.protocol import RegisterMessage +from cassandra.connection import (Connection, ConnectionShutdown, + NONBLOCKING, Timer, TimerManager) try: import cassandra.io.libevwrapper as libev except ImportError: @@ -44,7 +45,6 @@ def _cleanup(loop_weakref): loop = loop_weakref() except ReferenceError: return - loop._cleanup() @@ -79,10 +79,10 @@ class LibevLoop(object): self._loop.unref() self._preparer.start() - atexit.register(partial(_cleanup, weakref.ref(self))) + self._timers = TimerManager() + self._loop_timer = libev.Timer(self._loop, self._on_loop_timer) - def notify(self): - self._notifier.send() + atexit.register(partial(_cleanup, weakref.ref(self))) def maybe_start(self): should_start = False @@ -127,6 +127,7 @@ class LibevLoop(object): conn._read_watcher.stop() del conn._read_watcher + self.notify() # wake the timer watcher log.debug("Waiting for event loop thread to join...") self._thread.join(timeout=1.0) if self._thread.is_alive(): @@ -137,6 +138,24 @@ class LibevLoop(object): log.debug("Event loop thread was joined") self._loop = None + def add_timer(self, timer): + self._timers.add_timer(timer) + self._notifier.send() # wake up in case this timer is earlier + + def _update_timer(self): + if not self._shutdown: + self._timers.service_timeouts() + offset = self._timers.next_offset or 100000 # none pending; will be updated again when something new happens + self._loop_timer.start(offset) + else: + self._loop_timer.stop() + + def _on_loop_timer(self): + self._timers.service_timeouts() + + def notify(self): + self._notifier.send() + def connection_created(self, conn): with self._conn_set_lock: new_live_conns = self._live_conns.copy() @@ -199,6 +218,9 @@ class LibevLoop(object): changed = True + # TODO: update to do connection management, timer updates through dedicated async 'notifier' callbacks + self._update_timer() + if changed: self._notifier.send() @@ -229,12 +251,15 @@ class LibevConnection(Connection): cls._libevloop._cleanup() cls._libevloop = None + @classmethod + def create_timer(cls, timeout, callback): + timer = Timer(timeout, callback) + cls._libevloop.add_timer(timer) + return timer + def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) - self.connected_event = Event() - - self._callbacks = {} self.deque = deque() self._deque_lock = Lock() self._connect_socket() @@ -332,7 +357,7 @@ class LibevConnection(Connection): sabs = self.out_buffer_size if len(data) > sabs: chunks = [] - for i in xrange(0, len(data), sabs): + for i in range(0, len(data), sabs): chunks.append(data[i:i + sabs]) else: chunks = [data] @@ -340,14 +365,3 @@ class LibevConnection(Connection): with self._deque_lock: self.deque.extend(chunks) self._libevloop.notify() - - def register_watcher(self, event_type, callback, register_timeout=None): - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=[event_type]), timeout=register_timeout) - - def register_watchers(self, type_callback_dict, register_timeout=None): - for event_type, callback in type_callback_dict.items(): - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=type_callback_dict.keys()), timeout=register_timeout) diff --git a/cassandra/io/libevwrapper.c b/cassandra/io/libevwrapper.c index cbac83b2..99e1df30 100644 --- a/cassandra/io/libevwrapper.c +++ b/cassandra/io/libevwrapper.c @@ -451,6 +451,131 @@ static PyTypeObject libevwrapper_PrepareType = { (initproc)Prepare_init, /* tp_init */ }; +typedef struct libevwrapper_Timer { + PyObject_HEAD + struct ev_timer timer; + struct libevwrapper_Loop *loop; + PyObject *callback; +} libevwrapper_Timer; + +static void +Timer_dealloc(libevwrapper_Timer *self) { + Py_XDECREF(self->loop); + Py_XDECREF(self->callback); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static void timer_callback(struct ev_loop *loop, ev_timer *watcher, int revents) { + libevwrapper_Timer *self = watcher->data; + + PyObject *result = NULL; + PyGILState_STATE gstate; + + gstate = PyGILState_Ensure(); + result = PyObject_CallFunction(self->callback, NULL); + if (!result) { + PyErr_WriteUnraisable(self->callback); + } + Py_XDECREF(result); + + PyGILState_Release(gstate); +} + +static int +Timer_init(libevwrapper_Timer *self, PyObject *args, PyObject *kwds) { + PyObject *callback; + PyObject *loop; + + if (!PyArg_ParseTuple(args, "OO", &loop, &callback)) { + return -1; + } + + if (loop) { + Py_INCREF(loop); + self->loop = (libevwrapper_Loop *)loop; + } else { + return -1; + } + + if (callback) { + if (!PyCallable_Check(callback)) { + PyErr_SetString(PyExc_TypeError, "callback parameter must be callable"); + Py_XDECREF(loop); + return -1; + } + Py_INCREF(callback); + self->callback = callback; + } + ev_init(&self->timer, timer_callback); + self->timer.data = self; + return 0; +} + +static PyObject * +Timer_start(libevwrapper_Timer *self, PyObject *args) { + double timeout; + if (!PyArg_ParseTuple(args, "d", &timeout)) { + return NULL; + } + /* some tiny non-zero number to avoid zero, and + make it run immediately for negative timeouts */ + self->timer.repeat = fmax(timeout, 0.000000001); + ev_timer_again(self->loop->loop, &self->timer); + Py_RETURN_NONE; +} + +static PyObject * +Timer_stop(libevwrapper_Timer *self, PyObject *args) { + ev_timer_stop(self->loop->loop, &self->timer); + Py_RETURN_NONE; +} + +static PyMethodDef Timer_methods[] = { + {"start", (PyCFunction)Timer_start, METH_VARARGS, "Start the Timer watcher"}, + {"stop", (PyCFunction)Timer_stop, METH_NOARGS, "Stop the Timer watcher"}, + {NULL} /* Sentinal */ +}; + +static PyTypeObject libevwrapper_TimerType = { + PyVarObject_HEAD_INIT(NULL, 0) + "cassandra.io.libevwrapper.Timer", /*tp_name*/ + sizeof(libevwrapper_Timer), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)Timer_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + "Timer objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Timer_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)Timer_init, /* tp_init */ +}; + + static PyMethodDef module_methods[] = { {NULL} /* Sentinal */ }; @@ -500,6 +625,10 @@ initlibevwrapper(void) if (PyType_Ready(&libevwrapper_AsyncType) < 0) INITERROR; + libevwrapper_TimerType.tp_new = PyType_GenericNew; + if (PyType_Ready(&libevwrapper_TimerType) < 0) + INITERROR; + # if PY_MAJOR_VERSION >= 3 module = PyModule_Create(&moduledef); # else @@ -532,6 +661,10 @@ initlibevwrapper(void) if (PyModule_AddObject(module, "Async", (PyObject *)&libevwrapper_AsyncType) == -1) INITERROR; + Py_INCREF(&libevwrapper_TimerType); + if (PyModule_AddObject(module, "Timer", (PyObject *)&libevwrapper_TimerType) == -1) + INITERROR; + if (!PyEval_ThreadsInitialized()) { PyEval_InitThreads(); } diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py index 1de8de15..b02fb6ab 100644 --- a/cassandra/io/twistedreactor.py +++ b/cassandra/io/twistedreactor.py @@ -15,15 +15,15 @@ Module that implements an event loop based on twisted ( https://twistedmatrix.com ). """ -from twisted.internet import reactor, protocol -from threading import Event, Thread, Lock +import atexit from functools import partial import logging +from threading import Thread, Lock +import time +from twisted.internet import reactor, protocol import weakref -import atexit -from cassandra.connection import Connection, ConnectionShutdown -from cassandra.protocol import RegisterMessage +from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) @@ -108,9 +108,12 @@ class TwistedLoop(object): _lock = None _thread = None + _timeout_task = None + _timeout = None def __init__(self): self._lock = Lock() + self._timers = TimerManager() def maybe_start(self): with self._lock: @@ -132,6 +135,27 @@ class TwistedLoop(object): "Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") + def add_timer(self, timer): + self._timers.add_timer(timer) + # callFromThread to schedule from the loop thread, where + # the timeout task can safely be modified + reactor.callFromThread(self._schedule_timeout, timer.end) + + def _schedule_timeout(self, next_timeout): + if next_timeout: + delay = max(next_timeout - time.time(), 0) + if self._timeout_task and self._timeout_task.active(): + if next_timeout < self._timeout: + self._timeout_task.reset(delay) + self._timeout = next_timeout + else: + self._timeout_task = reactor.callLater(delay, self._on_loop_timer) + self._timeout = next_timeout + + def _on_loop_timer(self): + self._timers.service_timeouts() + self._schedule_timeout(self._timers.next_timeout) + class TwistedConnection(Connection): """ @@ -146,6 +170,12 @@ class TwistedConnection(Connection): if not cls._loop: cls._loop = TwistedLoop() + @classmethod + def create_timer(cls, timeout, callback): + timer = Timer(timeout, callback) + cls._loop.add_timer(timer) + return timer + def __init__(self, *args, **kwargs): """ Initialization method. @@ -157,11 +187,9 @@ class TwistedConnection(Connection): """ Connection.__init__(self, *args, **kwargs) - self.connected_event = Event() self.is_closed = True self.connector = None - self._callbacks = {} reactor.callFromThread(self.add_connection) self._loop.maybe_start() @@ -218,22 +246,3 @@ class TwistedConnection(Connection): the event loop when it gets the chance. """ reactor.callFromThread(self.connector.transport.write, data) - - def register_watcher(self, event_type, callback, register_timeout=None): - """ - Register a callback for a given event type. - """ - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=[event_type]), - timeout=register_timeout) - - def register_watchers(self, type_callback_dict, register_timeout=None): - """ - Register multiple callback/event type pairs, expressed as a dict. - """ - for event_type, callback in type_callback_dict.items(): - self._push_watchers[event_type].add(callback) - self.wait_for_response( - RegisterMessage(event_list=type_callback_dict.keys()), - timeout=register_timeout) diff --git a/cassandra/query.py b/cassandra/query.py index 6709cbef..21b668df 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -917,14 +917,14 @@ class QueryTrace(object): break def _execute(self, query, parameters, time_spent, max_wait): + timeout = (max_wait - time_spent) if max_wait is not None else None + future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) # in case the user switched the row factory, set it to namedtuple for this query - future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None) future.row_factory = named_tuple_factory future.send_request() - timeout = (max_wait - time_spent) if max_wait is not None else None try: - return future.result(timeout=timeout) + return future.result() except OperationTimedOut: raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) diff --git a/docs/object_mapper.rst b/docs/object_mapper.rst index 26d78a09..4e389940 100644 --- a/docs/object_mapper.rst +++ b/docs/object_mapper.rst @@ -48,7 +48,7 @@ Getting Started from cassandra.cqlengine import columns from cassandra.cqlengine import connection from datetime import datetime - from cassandra.cqlengine.management import create_keyspace, sync_table + from cassandra.cqlengine.management import sync_table from cassandra.cqlengine.models import Model #first, define a model diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index e5be6075..268e19d5 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -244,10 +244,7 @@ class ConnectionTest(unittest.TestCase): Ensure the following methods throw NIE's. If not, come back and test them. """ c = self.make_connection() - self.assertRaises(NotImplementedError, c.close) - self.assertRaises(NotImplementedError, c.register_watcher, None, None) - self.assertRaises(NotImplementedError, c.register_watchers, None) def test_set_keyspace_blocking(self): c = self.make_connection() diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 92351a9d..eea43f75 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -47,7 +47,7 @@ class ResponseFutureTests(unittest.TestCase): def make_response_future(self, session): query = SimpleStatement("SELECT * FROM foo") message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - return ResponseFuture(session, message, query) + return ResponseFuture(session, message, query, 1) def make_mock_response(self, results): return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=results, paging_state=None) @@ -122,7 +122,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_read_timeout.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query) + rf = ResponseFuture(session, message, query, 1) rf.send_request() result = Mock(spec=ReadTimeoutErrorMessage, info={}) @@ -137,7 +137,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_write_timeout.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query) + rf = ResponseFuture(session, message, query, 1) rf.send_request() result = Mock(spec=WriteTimeoutErrorMessage, info={}) @@ -151,7 +151,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query) + rf = ResponseFuture(session, message, query, 1) rf.send_request() result = Mock(spec=UnavailableErrorMessage, info={}) @@ -165,7 +165,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_unavailable.return_value = (RetryPolicy.IGNORE, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query) + rf = ResponseFuture(session, message, query, 1) rf.send_request() result = Mock(spec=UnavailableErrorMessage, info={}) @@ -184,7 +184,7 @@ class ResponseFutureTests(unittest.TestCase): connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) - rf = ResponseFuture(session, message, query) + rf = ResponseFuture(session, message, query, 1) rf.send_request() rf.session._pools.get.assert_called_once_with('ip1') @@ -279,7 +279,7 @@ class ResponseFutureTests(unittest.TestCase): session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2'] session._pools.get.return_value.is_shutdown = True - rf = ResponseFuture(session, Mock(), Mock()) + rf = ResponseFuture(session, Mock(), Mock(), 1) rf.send_request() self.assertRaises(NoHostAvailable, rf.result) @@ -354,7 +354,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query) + rf = ResponseFuture(session, message, query, 1) rf.send_request() rf.add_errback(self.assertIsInstance, Exception) @@ -401,7 +401,7 @@ class ResponseFutureTests(unittest.TestCase): query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query) + rf = ResponseFuture(session, message, query, 1) rf.send_request() callback = Mock() @@ -431,7 +431,7 @@ class ResponseFutureTests(unittest.TestCase): message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) # test errback - rf = ResponseFuture(session, message, query) + rf = ResponseFuture(session, message, query, 1) rf.send_request() rf.add_callbacks( @@ -443,7 +443,7 @@ class ResponseFutureTests(unittest.TestCase): self.assertRaises(Exception, rf.result) # test callback - rf = ResponseFuture(session, message, query) + rf = ResponseFuture(session, message, query, 1) rf.send_request() callback = Mock()