Revert "Revert "Merge pull request #298 from datastax/PYTHON-108""

This reverts commit dfa91b8bd5.

Conflicts:
	cassandra/cluster.py
	cassandra/connection.py
	cassandra/io/asyncorereactor.py
	cassandra/io/eventletreactor.py
	cassandra/io/geventreactor.py
	cassandra/io/libevreactor.py
	cassandra/io/twistedreactor.py
	cassandra/query.py
This commit is contained in:
Adam Holmberg 2015-06-19 10:08:48 -05:00
parent da7727b905
commit ee2243c217
12 changed files with 460 additions and 196 deletions

View File

@ -1442,8 +1442,7 @@ class Session(object):
"""
A default timeout, measured in seconds, for queries executed through
:meth:`.execute()` or :meth:`.execute_async()`. This default may be
overridden with the `timeout` parameter for either of those methods
or the `timeout` parameter for :meth:`.ResponseFuture.result()`.
overridden with the `timeout` parameter for either of those methods.
Setting this to :const:`None` will cause no timeouts to be set by default.
@ -1581,17 +1580,14 @@ class Session(object):
If `query` is a Statement with its own custom_payload. The message payload
will be a union of the two, with the values specified here taking precedence.
"""
if timeout is _NOT_SET:
timeout = self.default_timeout
if trace and not isinstance(query, Statement):
raise TypeError(
"The query argument must be an instance of a subclass of "
"cassandra.query.Statement when trace=True")
future = self.execute_async(query, parameters, trace, custom_payload)
future = self.execute_async(query, parameters, trace, custom_payload, timeout)
try:
result = future.result(timeout)
result = future.result()
finally:
if trace:
try:
@ -1601,7 +1597,7 @@ class Session(object):
return result
def execute_async(self, query, parameters=None, trace=False, custom_payload=None):
def execute_async(self, query, parameters=None, trace=False, custom_payload=None, timeout=_NOT_SET):
"""
Execute the given query and return a :class:`~.ResponseFuture` object
which callbacks may be attached to for asynchronous response
@ -1646,11 +1642,14 @@ class Session(object):
... log.exception("Operation failed:")
"""
future = self._create_response_future(query, parameters, trace, custom_payload)
if timeout is _NOT_SET:
timeout = self.default_timeout
future = self._create_response_future(query, parameters, trace, custom_payload, timeout)
future.send_request()
return future
def _create_response_future(self, query, parameters, trace, custom_payload):
def _create_response_future(self, query, parameters, trace, custom_payload, timeout):
""" Returns the ResponseFuture before calling send_request() on it """
prepared_statement = None
@ -1704,7 +1703,7 @@ class Session(object):
message.update_custom_payload(custom_payload)
return ResponseFuture(
self, message, query, self.default_timeout, metrics=self._metrics,
self, message, query, timeout, metrics=self._metrics,
prepared_statement=prepared_statement)
def prepare(self, query, custom_payload=None):
@ -1737,11 +1736,10 @@ class Session(object):
message. See :ref:`custom_payload`.
"""
message = PrepareMessage(query=query)
message.custom_payload = custom_payload
future = ResponseFuture(self, message, query=None)
future = ResponseFuture(self, message, query=None, timeout=self.default_timeout)
try:
future.send_request()
query_id, column_metadata, pk_indexes = future.result(self.default_timeout)
query_id, column_metadata, pk_indexes = future.result()
except Exception:
log.exception("Error preparing query:")
raise
@ -1767,7 +1765,7 @@ class Session(object):
futures = []
for host in self._pools.keys():
if host != excluded_host and host.is_up:
future = ResponseFuture(self, PrepareMessage(query=query), None)
future = ResponseFuture(self, PrepareMessage(query=query), None, self.default_timeout)
# we don't care about errors preparing against specific hosts,
# since we can always prepare them as needed when the prepared
@ -1788,7 +1786,7 @@ class Session(object):
for host, future in futures:
try:
future.result(self.default_timeout)
future.result()
except Exception:
log.exception("Error preparing query for host %s:", host)
@ -2832,13 +2830,14 @@ class ResponseFuture(object):
_paging_state = None
_custom_payload = None
_warnings = None
_timer = None
def __init__(self, session, message, query, default_timeout=None, metrics=None, prepared_statement=None):
def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None):
self.session = session
self.row_factory = session.row_factory
self.message = message
self.query = query
self.default_timeout = default_timeout
self.timeout = timeout
self._metrics = metrics
self.prepared_statement = prepared_statement
self._callback_lock = Lock()
@ -2849,6 +2848,18 @@ class ResponseFuture(object):
self._errors = {}
self._callbacks = []
self._errbacks = []
self._start_timer()
def _start_timer(self):
if self.timeout is not None:
self._timer = self.session.cluster.connection_class.create_timer(self.timeout, self._on_timeout)
def _cancel_timer(self):
if self._timer:
self._timer.cancel()
def _on_timeout(self):
self._set_final_exception(OperationTimedOut(self._errors, self._current_host))
def _make_query_plan(self):
# convert the list/generator/etc to an iterator so that subsequent
@ -2973,6 +2984,7 @@ class ResponseFuture(object):
self._event.clear()
self._final_result = _NOT_SET
self._final_exception = None
self._start_timer()
self.send_request()
def _reprepare(self, prepare_message):
@ -3187,6 +3199,7 @@ class ResponseFuture(object):
"statement on host %s: %s" % (self._current_host, response)))
def _set_final_result(self, response):
self._cancel_timer()
if self._metrics is not None:
self._metrics.request_timer.addValue(time.time() - self._start_time)
@ -3201,6 +3214,7 @@ class ResponseFuture(object):
fn(response, *args, **kwargs)
def _set_final_exception(self, response):
self._cancel_timer()
if self._metrics is not None:
self._metrics.request_timer.addValue(time.time() - self._start_time)
@ -3244,6 +3258,11 @@ class ResponseFuture(object):
encountered. If the final result or error has not been set
yet, this method will block until that time.
.. versionchanged:: 2.6.0
**`timeout` is deprecated. Use timeout in the Session execute functions instead.
The following description applies to deprecated behavior:**
You may set a timeout (in seconds) with the `timeout` parameter.
By default, the :attr:`~.default_timeout` for the :class:`.Session`
this was created through will be used for the timeout on this
@ -3257,11 +3276,6 @@ class ResponseFuture(object):
This is a client-side timeout. For more information
about server-side coordinator timeouts, see :class:`.policies.RetryPolicy`.
**Important**: This timeout currently has no effect on callbacks registered
on a :class:`~.ResponseFuture` through :meth:`.ResponseFuture.add_callback` or
:meth:`.ResponseFuture.add_errback`; even if a query exceeds this default
timeout, neither the registered callback or errback will be called.
Example usage::
>>> future = session.execute_async("SELECT * FROM mycf")
@ -3275,27 +3289,24 @@ class ResponseFuture(object):
... log.exception("Operation failed:")
"""
if timeout is _NOT_SET:
timeout = self.default_timeout
if timeout is not _NOT_SET:
msg = "ResponseFuture.result timeout argument is deprecated. Specify the request timeout via Session.execute[_async]."
warnings.warn(msg, DeprecationWarning)
log.warning(msg)
else:
timeout = None
self._event.wait(timeout)
# TODO: remove this conditional when deprecated timeout parameter is removed
if not self._event.is_set():
self._on_timeout()
if self._final_result is not _NOT_SET:
if self._paging_state is None:
return self._final_result
else:
return PagedResult(self, self._final_result, timeout)
elif self._final_exception:
raise self._final_exception
return PagedResult(self, self._final_result)
else:
self._event.wait(timeout=timeout)
if self._final_result is not _NOT_SET:
if self._paging_state is None:
return self._final_result
else:
return PagedResult(self, self._final_result, timeout)
elif self._final_exception:
raise self._final_exception
else:
raise OperationTimedOut(errors=self._errors, last_host=self._current_host)
raise self._final_exception
def get_query_trace(self, max_wait=None):
"""
@ -3450,10 +3461,9 @@ class PagedResult(object):
response_future = None
def __init__(self, response_future, initial_response, timeout=_NOT_SET):
def __init__(self, response_future, initial_response):
self.response_future = response_future
self.current_response = iter(initial_response)
self.timeout = timeout
def __iter__(self):
return self
@ -3466,7 +3476,7 @@ class PagedResult(object):
raise
self.response_future.start_fetching_next_page()
result = self.response_future.result(self.timeout)
result = self.response_future.result()
if self.response_future.has_more_pages:
self.current_response = result.current_response
else:

View File

@ -16,6 +16,7 @@ from __future__ import absolute_import # to enable import io from stdlib
from collections import defaultdict, deque, namedtuple
import errno
from functools import wraps, partial
from heapq import heappush, heappop
import io
import logging
import socket
@ -44,7 +45,8 @@ from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessag
QueryMessage, ResultMessage, decode_response,
InvalidRequestException, SupportedMessage,
AuthResponseMessage, AuthChallengeMessage,
AuthSuccessMessage, ProtocolException, MAX_SUPPORTED_VERSION)
AuthSuccessMessage, ProtocolException,
MAX_SUPPORTED_VERSION, RegisterMessage)
from cassandra.util import OrderedDict
@ -219,6 +221,7 @@ class Connection(object):
self.is_control_connection = is_control_connection
self.user_type_map = user_type_map
self._push_watchers = defaultdict(set)
self._callbacks = {}
self._iobuf = io.BytesIO()
if protocol_version >= 3:
@ -233,6 +236,7 @@ class Connection(object):
self.highest_request_id = self.max_request_id
self.lock = RLock()
self.connected_event = Event()
@classmethod
def initialize_reactor(self):
@ -250,6 +254,10 @@ class Connection(object):
"""
pass
@classmethod
def create_timer(cls, timeout, callback):
raise NotImplementedError()
@classmethod
def factory(cls, host, timeout, *args, **kwargs):
"""
@ -407,11 +415,24 @@ class Connection(object):
self.defunct(exc)
raise
def register_watcher(self, event_type, callback):
raise NotImplementedError()
def register_watcher(self, event_type, callback, register_timeout=None):
"""
Register a callback for a given event type.
"""
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=[event_type]),
timeout=register_timeout)
def register_watchers(self, type_callback_dict):
raise NotImplementedError()
def register_watchers(self, type_callback_dict, register_timeout=None):
"""
Register multiple callback/event type pairs, expressed as a dict.
"""
for event_type, callback in type_callback_dict.items():
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=type_callback_dict.keys()),
timeout=register_timeout)
def control_conn_disposed(self):
self.is_control_connection = False
@ -907,3 +928,77 @@ class ConnectionHeartbeat(Thread):
def _raise_if_stopped(self):
if self._shutdown_event.is_set():
raise self.ShutdownException()
class Timer(object):
canceled = False
def __init__(self, timeout, callback):
self.end = time.time() + timeout
self.callback = callback
if timeout < 0:
self.callback()
def cancel(self):
self.canceled = True
def finish(self, time_now):
if self.canceled:
return True
if time_now >= self.end:
self.callback()
return True
return False
class TimerManager(object):
def __init__(self):
self._queue = []
self._new_timers = []
def add_timer(self, timer):
"""
called from client thread with a Timer object
"""
self._new_timers.append((timer.end, timer))
def service_timeouts(self):
"""
run callbacks on all expired timers
Called from the event thread
:return: next end time, or None
"""
queue = self._queue
new_timers = self._new_timers
while new_timers:
heappush(queue, new_timers.pop())
now = time.time()
while queue:
try:
timer = queue[0][1]
if timer.finish(now):
heappop(queue)
else:
return timer.end
except Exception:
log.exception("Exception while servicing timeout callback: ")
@property
def next_timeout(self):
try:
return self._queue[0][0]
except IndexError:
pass
@property
def next_offset(self):
try:
next_end = self._queue[0][0]
return next_end - time.time()
except IndexError:
pass

View File

@ -19,6 +19,7 @@ import os
import socket
import sys
from threading import Event, Lock, Thread
import time
import weakref
from six.moves import range
@ -35,8 +36,9 @@ try:
except ImportError:
ssl = None # NOQA
from cassandra.connection import (Connection, ConnectionShutdown, NONBLOCKING)
from cassandra.protocol import RegisterMessage
from cassandra.connection import (Connection, ConnectionShutdown,
ConnectionException, NONBLOCKING,
Timer, TimerManager)
log = logging.getLogger(__name__)
@ -52,15 +54,17 @@ def _cleanup(loop_weakref):
class AsyncoreLoop(object):
def __init__(self):
self._pid = os.getpid()
self._loop_lock = Lock()
self._started = False
self._shutdown = False
self._conns_lock = Lock()
self._conns = WeakSet()
self._thread = None
self._timers = TimerManager()
atexit.register(partial(_cleanup, weakref.ref(self)))
def maybe_start(self):
@ -83,24 +87,22 @@ class AsyncoreLoop(object):
def _run_loop(self):
log.debug("Starting asyncore event loop")
with self._loop_lock:
while True:
while not self._shutdown:
try:
asyncore.loop(timeout=0.001, use_poll=True, count=1000)
asyncore.loop(timeout=0.001, use_poll=True, count=100)
self._timers.service_timeouts()
if not asyncore.socket_map:
time.sleep(0.005)
except Exception:
log.debug("Asyncore event loop stopped unexepectedly", exc_info=True)
break
if self._shutdown:
break
with self._conns_lock:
if len(self._conns) == 0:
break
self._started = False
log.debug("Asyncore event loop ended")
def add_timer(self, timer):
self._timers.add_timer(timer)
def _cleanup(self):
self._shutdown = True
if not self._thread:
@ -115,14 +117,6 @@ class AsyncoreLoop(object):
log.debug("Event loop thread was joined")
def connection_created(self, connection):
with self._conns_lock:
self._conns.add(connection)
def connection_destroyed(self, connection):
with self._conns_lock:
self._conns.discard(connection)
class AsyncoreConnection(Connection, asyncore.dispatcher):
"""
@ -152,18 +146,19 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
cls._loop._cleanup()
cls._loop = None
@classmethod
def create_timer(cls, timeout, callback):
timer = Timer(timeout, callback)
cls._loop.add_timer(timer)
return timer
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
asyncore.dispatcher.__init__(self)
self.connected_event = Event()
self._callbacks = {}
self.deque = deque()
self.deque_lock = Lock()
self._loop.connection_created(self)
self._connect_socket()
asyncore.dispatcher.__init__(self, self._socket)
@ -187,8 +182,6 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
asyncore.dispatcher.close(self)
log.debug("Closed socket to %s", self.host)
self._loop.connection_destroyed(self)
if not self.is_defunct:
self.error_all_callbacks(
ConnectionShutdown("Connection to %s was closed" % self.host))
@ -267,14 +260,3 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
def readable(self):
return self._readable or (self.is_control_connection and not (self.is_defunct or self.is_closed))
def register_watcher(self, event_type, callback, register_timeout=None):
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=[event_type]), timeout=register_timeout)
def register_watchers(self, type_callback_dict, register_timeout=None):
for event_type, callback in type_callback_dict.items():
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=type_callback_dict.keys()), timeout=register_timeout)

View File

@ -16,7 +16,6 @@
# Originally derived from MagnetoDB source:
# https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py
from collections import defaultdict
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL
import eventlet
from eventlet.green import select, socket
@ -24,11 +23,12 @@ from eventlet.queue import Queue
from functools import partial
import logging
import os
from six.moves import xrange
from threading import Event
import time
from cassandra.connection import Connection, ConnectionShutdown
from cassandra.protocol import RegisterMessage
from six.moves import xrange
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
log = logging.getLogger(__name__)
@ -52,19 +52,45 @@ class EventletConnection(Connection):
_socket_impl = eventlet.green.socket
_ssl_impl = eventlet.green.ssl
_timers = None
_timeout_watcher = None
_new_timer = None
@classmethod
def initialize_reactor(cls):
eventlet.monkey_patch()
if not cls._timers:
cls._timers = TimerManager()
cls._timeout_watcher = eventlet.spawn(cls.service_timeouts)
cls._new_timer = Event()
@classmethod
def create_timer(cls, timeout, callback):
timer = Timer(timeout, callback)
cls._timers.add_timer(timer)
cls._new_timer.set()
return timer
@classmethod
def service_timeouts(cls):
"""
cls._timeout_watcher runs in this loop forever.
It is usually waiting for the next timeout on the cls._new_timer Event.
When new timers are added, that event is set so that the watcher can
wake up and possibly set an earlier timeout.
"""
timer_manager = cls._timers
while True:
next_end = timer_manager.service_timeouts()
sleep_time = max(next_end - time.time(), 0) if next_end else 10000
cls._new_timer.wait(sleep_time)
cls._new_timer.clear()
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self.connected_event = Event()
self._write_queue = Queue()
self._callbacks = {}
self._push_watchers = defaultdict(set)
self._connect_socket()
self._read_watcher = eventlet.spawn(lambda: self.handle_read())
@ -142,16 +168,3 @@ class EventletConnection(Connection):
chunk_size = self.out_buffer_size
for i in xrange(0, len(data), chunk_size):
self._write_queue.put(data[i:i + chunk_size])
def register_watcher(self, event_type, callback, register_timeout=None):
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=[event_type]),
timeout=register_timeout)
def register_watchers(self, type_callback_dict, register_timeout=None):
for event_type, callback in type_callback_dict.items():
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=type_callback_dict.keys()),
timeout=register_timeout)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gevent
from gevent.event import Event
import gevent.event
from gevent.queue import Queue
from gevent import select, socket
import gevent.ssl
@ -21,13 +21,13 @@ from collections import defaultdict
from functools import partial
import logging
import os
import time
from six.moves import xrange
from six.moves import range
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL
from cassandra.connection import Connection, ConnectionShutdown
from cassandra.protocol import RegisterMessage
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
log = logging.getLogger(__name__)
@ -51,15 +51,39 @@ class GeventConnection(Connection):
_socket_impl = gevent.socket
_ssl_impl = gevent.ssl
_timers = None
_timeout_watcher = None
_new_timer = None
@classmethod
def initialize_reactor(cls):
if not cls._timers:
cls._timers = TimerManager()
cls._timeout_watcher = gevent.spawn(cls.service_timeouts)
cls._new_timer = gevent.event.Event()
@classmethod
def create_timer(cls, timeout, callback):
timer = Timer(timeout, callback)
cls._timers.add_timer(timer)
cls._new_timer.set()
return timer
@classmethod
def service_timeouts(cls):
timer_manager = cls._timers
timer_event = cls._new_timer
while True:
next_end = timer_manager.service_timeouts()
sleep_time = max(next_end - time.time(), 0) if next_end else 10000
timer_event.wait(sleep_time)
timer_event.clear()
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self.connected_event = Event()
self._write_queue = Queue()
self._callbacks = {}
self._push_watchers = defaultdict(set)
self._connect_socket()
self._read_watcher = gevent.spawn(self.handle_read)
@ -142,18 +166,5 @@ class GeventConnection(Connection):
def push(self, data):
chunk_size = self.out_buffer_size
for i in xrange(0, len(data), chunk_size):
for i in range(0, len(data), chunk_size):
self._write_queue.put(data[i:i + chunk_size])
def register_watcher(self, event_type, callback, register_timeout=None):
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=[event_type]),
timeout=register_timeout)
def register_watchers(self, type_callback_dict, register_timeout=None):
for event_type, callback in type_callback_dict.items():
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=type_callback_dict.keys()),
timeout=register_timeout)

View File

@ -17,13 +17,14 @@ from functools import partial
import logging
import os
import socket
from threading import Event, Lock, Thread
import ssl
from threading import Lock, Thread
import weakref
from six.moves import xrange
from six.moves import range
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING, ssl
from cassandra.protocol import RegisterMessage
from cassandra.connection import (Connection, ConnectionShutdown,
NONBLOCKING, Timer, TimerManager)
try:
import cassandra.io.libevwrapper as libev
except ImportError:
@ -44,7 +45,6 @@ def _cleanup(loop_weakref):
loop = loop_weakref()
except ReferenceError:
return
loop._cleanup()
@ -79,10 +79,10 @@ class LibevLoop(object):
self._loop.unref()
self._preparer.start()
atexit.register(partial(_cleanup, weakref.ref(self)))
self._timers = TimerManager()
self._loop_timer = libev.Timer(self._loop, self._on_loop_timer)
def notify(self):
self._notifier.send()
atexit.register(partial(_cleanup, weakref.ref(self)))
def maybe_start(self):
should_start = False
@ -127,6 +127,7 @@ class LibevLoop(object):
conn._read_watcher.stop()
del conn._read_watcher
self.notify() # wake the timer watcher
log.debug("Waiting for event loop thread to join...")
self._thread.join(timeout=1.0)
if self._thread.is_alive():
@ -137,6 +138,24 @@ class LibevLoop(object):
log.debug("Event loop thread was joined")
self._loop = None
def add_timer(self, timer):
self._timers.add_timer(timer)
self._notifier.send() # wake up in case this timer is earlier
def _update_timer(self):
if not self._shutdown:
self._timers.service_timeouts()
offset = self._timers.next_offset or 100000 # none pending; will be updated again when something new happens
self._loop_timer.start(offset)
else:
self._loop_timer.stop()
def _on_loop_timer(self):
self._timers.service_timeouts()
def notify(self):
self._notifier.send()
def connection_created(self, conn):
with self._conn_set_lock:
new_live_conns = self._live_conns.copy()
@ -199,6 +218,9 @@ class LibevLoop(object):
changed = True
# TODO: update to do connection management, timer updates through dedicated async 'notifier' callbacks
self._update_timer()
if changed:
self._notifier.send()
@ -229,12 +251,15 @@ class LibevConnection(Connection):
cls._libevloop._cleanup()
cls._libevloop = None
@classmethod
def create_timer(cls, timeout, callback):
timer = Timer(timeout, callback)
cls._libevloop.add_timer(timer)
return timer
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self.connected_event = Event()
self._callbacks = {}
self.deque = deque()
self._deque_lock = Lock()
self._connect_socket()
@ -332,7 +357,7 @@ class LibevConnection(Connection):
sabs = self.out_buffer_size
if len(data) > sabs:
chunks = []
for i in xrange(0, len(data), sabs):
for i in range(0, len(data), sabs):
chunks.append(data[i:i + sabs])
else:
chunks = [data]
@ -340,14 +365,3 @@ class LibevConnection(Connection):
with self._deque_lock:
self.deque.extend(chunks)
self._libevloop.notify()
def register_watcher(self, event_type, callback, register_timeout=None):
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=[event_type]), timeout=register_timeout)
def register_watchers(self, type_callback_dict, register_timeout=None):
for event_type, callback in type_callback_dict.items():
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=type_callback_dict.keys()), timeout=register_timeout)

View File

@ -451,6 +451,131 @@ static PyTypeObject libevwrapper_PrepareType = {
(initproc)Prepare_init, /* tp_init */
};
typedef struct libevwrapper_Timer {
PyObject_HEAD
struct ev_timer timer;
struct libevwrapper_Loop *loop;
PyObject *callback;
} libevwrapper_Timer;
static void
Timer_dealloc(libevwrapper_Timer *self) {
Py_XDECREF(self->loop);
Py_XDECREF(self->callback);
Py_TYPE(self)->tp_free((PyObject *)self);
}
static void timer_callback(struct ev_loop *loop, ev_timer *watcher, int revents) {
libevwrapper_Timer *self = watcher->data;
PyObject *result = NULL;
PyGILState_STATE gstate;
gstate = PyGILState_Ensure();
result = PyObject_CallFunction(self->callback, NULL);
if (!result) {
PyErr_WriteUnraisable(self->callback);
}
Py_XDECREF(result);
PyGILState_Release(gstate);
}
static int
Timer_init(libevwrapper_Timer *self, PyObject *args, PyObject *kwds) {
PyObject *callback;
PyObject *loop;
if (!PyArg_ParseTuple(args, "OO", &loop, &callback)) {
return -1;
}
if (loop) {
Py_INCREF(loop);
self->loop = (libevwrapper_Loop *)loop;
} else {
return -1;
}
if (callback) {
if (!PyCallable_Check(callback)) {
PyErr_SetString(PyExc_TypeError, "callback parameter must be callable");
Py_XDECREF(loop);
return -1;
}
Py_INCREF(callback);
self->callback = callback;
}
ev_init(&self->timer, timer_callback);
self->timer.data = self;
return 0;
}
static PyObject *
Timer_start(libevwrapper_Timer *self, PyObject *args) {
double timeout;
if (!PyArg_ParseTuple(args, "d", &timeout)) {
return NULL;
}
/* some tiny non-zero number to avoid zero, and
make it run immediately for negative timeouts */
self->timer.repeat = fmax(timeout, 0.000000001);
ev_timer_again(self->loop->loop, &self->timer);
Py_RETURN_NONE;
}
static PyObject *
Timer_stop(libevwrapper_Timer *self, PyObject *args) {
ev_timer_stop(self->loop->loop, &self->timer);
Py_RETURN_NONE;
}
static PyMethodDef Timer_methods[] = {
{"start", (PyCFunction)Timer_start, METH_VARARGS, "Start the Timer watcher"},
{"stop", (PyCFunction)Timer_stop, METH_NOARGS, "Stop the Timer watcher"},
{NULL} /* Sentinal */
};
static PyTypeObject libevwrapper_TimerType = {
PyVarObject_HEAD_INIT(NULL, 0)
"cassandra.io.libevwrapper.Timer", /*tp_name*/
sizeof(libevwrapper_Timer), /*tp_basicsize*/
0, /*tp_itemsize*/
(destructor)Timer_dealloc, /*tp_dealloc*/
0, /*tp_print*/
0, /*tp_getattr*/
0, /*tp_setattr*/
0, /*tp_compare*/
0, /*tp_repr*/
0, /*tp_as_number*/
0, /*tp_as_sequence*/
0, /*tp_as_mapping*/
0, /*tp_hash */
0, /*tp_call*/
0, /*tp_str*/
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
"Timer objects", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
Timer_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)Timer_init, /* tp_init */
};
static PyMethodDef module_methods[] = {
{NULL} /* Sentinal */
};
@ -500,6 +625,10 @@ initlibevwrapper(void)
if (PyType_Ready(&libevwrapper_AsyncType) < 0)
INITERROR;
libevwrapper_TimerType.tp_new = PyType_GenericNew;
if (PyType_Ready(&libevwrapper_TimerType) < 0)
INITERROR;
# if PY_MAJOR_VERSION >= 3
module = PyModule_Create(&moduledef);
# else
@ -532,6 +661,10 @@ initlibevwrapper(void)
if (PyModule_AddObject(module, "Async", (PyObject *)&libevwrapper_AsyncType) == -1)
INITERROR;
Py_INCREF(&libevwrapper_TimerType);
if (PyModule_AddObject(module, "Timer", (PyObject *)&libevwrapper_TimerType) == -1)
INITERROR;
if (!PyEval_ThreadsInitialized()) {
PyEval_InitThreads();
}

View File

@ -15,15 +15,15 @@
Module that implements an event loop based on twisted
( https://twistedmatrix.com ).
"""
from twisted.internet import reactor, protocol
from threading import Event, Thread, Lock
import atexit
from functools import partial
import logging
from threading import Thread, Lock
import time
from twisted.internet import reactor, protocol
import weakref
import atexit
from cassandra.connection import Connection, ConnectionShutdown
from cassandra.protocol import RegisterMessage
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
log = logging.getLogger(__name__)
@ -108,9 +108,12 @@ class TwistedLoop(object):
_lock = None
_thread = None
_timeout_task = None
_timeout = None
def __init__(self):
self._lock = Lock()
self._timers = TimerManager()
def maybe_start(self):
with self._lock:
@ -132,6 +135,27 @@ class TwistedLoop(object):
"Cluster.shutdown() to avoid this.")
log.debug("Event loop thread was joined")
def add_timer(self, timer):
self._timers.add_timer(timer)
# callFromThread to schedule from the loop thread, where
# the timeout task can safely be modified
reactor.callFromThread(self._schedule_timeout, timer.end)
def _schedule_timeout(self, next_timeout):
if next_timeout:
delay = max(next_timeout - time.time(), 0)
if self._timeout_task and self._timeout_task.active():
if next_timeout < self._timeout:
self._timeout_task.reset(delay)
self._timeout = next_timeout
else:
self._timeout_task = reactor.callLater(delay, self._on_loop_timer)
self._timeout = next_timeout
def _on_loop_timer(self):
self._timers.service_timeouts()
self._schedule_timeout(self._timers.next_timeout)
class TwistedConnection(Connection):
"""
@ -146,6 +170,12 @@ class TwistedConnection(Connection):
if not cls._loop:
cls._loop = TwistedLoop()
@classmethod
def create_timer(cls, timeout, callback):
timer = Timer(timeout, callback)
cls._loop.add_timer(timer)
return timer
def __init__(self, *args, **kwargs):
"""
Initialization method.
@ -157,11 +187,9 @@ class TwistedConnection(Connection):
"""
Connection.__init__(self, *args, **kwargs)
self.connected_event = Event()
self.is_closed = True
self.connector = None
self._callbacks = {}
reactor.callFromThread(self.add_connection)
self._loop.maybe_start()
@ -218,22 +246,3 @@ class TwistedConnection(Connection):
the event loop when it gets the chance.
"""
reactor.callFromThread(self.connector.transport.write, data)
def register_watcher(self, event_type, callback, register_timeout=None):
"""
Register a callback for a given event type.
"""
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=[event_type]),
timeout=register_timeout)
def register_watchers(self, type_callback_dict, register_timeout=None):
"""
Register multiple callback/event type pairs, expressed as a dict.
"""
for event_type, callback in type_callback_dict.items():
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=type_callback_dict.keys()),
timeout=register_timeout)

View File

@ -917,14 +917,14 @@ class QueryTrace(object):
break
def _execute(self, query, parameters, time_spent, max_wait):
timeout = (max_wait - time_spent) if max_wait is not None else None
future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout)
# in case the user switched the row factory, set it to namedtuple for this query
future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None)
future.row_factory = named_tuple_factory
future.send_request()
timeout = (max_wait - time_spent) if max_wait is not None else None
try:
return future.result(timeout=timeout)
return future.result()
except OperationTimedOut:
raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))

View File

@ -48,7 +48,7 @@ Getting Started
from cassandra.cqlengine import columns
from cassandra.cqlengine import connection
from datetime import datetime
from cassandra.cqlengine.management import create_keyspace, sync_table
from cassandra.cqlengine.management import sync_table
from cassandra.cqlengine.models import Model
#first, define a model

View File

@ -244,10 +244,7 @@ class ConnectionTest(unittest.TestCase):
Ensure the following methods throw NIE's. If not, come back and test them.
"""
c = self.make_connection()
self.assertRaises(NotImplementedError, c.close)
self.assertRaises(NotImplementedError, c.register_watcher, None, None)
self.assertRaises(NotImplementedError, c.register_watchers, None)
def test_set_keyspace_blocking(self):
c = self.make_connection()

View File

@ -47,7 +47,7 @@ class ResponseFutureTests(unittest.TestCase):
def make_response_future(self, session):
query = SimpleStatement("SELECT * FROM foo")
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
return ResponseFuture(session, message, query)
return ResponseFuture(session, message, query, 1)
def make_mock_response(self, results):
return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=results, paging_state=None)
@ -122,7 +122,7 @@ class ResponseFutureTests(unittest.TestCase):
query.retry_policy.on_read_timeout.return_value = (RetryPolicy.RETHROW, None)
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
rf = ResponseFuture(session, message, query)
rf = ResponseFuture(session, message, query, 1)
rf.send_request()
result = Mock(spec=ReadTimeoutErrorMessage, info={})
@ -137,7 +137,7 @@ class ResponseFutureTests(unittest.TestCase):
query.retry_policy.on_write_timeout.return_value = (RetryPolicy.RETHROW, None)
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
rf = ResponseFuture(session, message, query)
rf = ResponseFuture(session, message, query, 1)
rf.send_request()
result = Mock(spec=WriteTimeoutErrorMessage, info={})
@ -151,7 +151,7 @@ class ResponseFutureTests(unittest.TestCase):
query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None)
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
rf = ResponseFuture(session, message, query)
rf = ResponseFuture(session, message, query, 1)
rf.send_request()
result = Mock(spec=UnavailableErrorMessage, info={})
@ -165,7 +165,7 @@ class ResponseFutureTests(unittest.TestCase):
query.retry_policy.on_unavailable.return_value = (RetryPolicy.IGNORE, None)
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
rf = ResponseFuture(session, message, query)
rf = ResponseFuture(session, message, query, 1)
rf.send_request()
result = Mock(spec=UnavailableErrorMessage, info={})
@ -184,7 +184,7 @@ class ResponseFutureTests(unittest.TestCase):
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)
rf = ResponseFuture(session, message, query)
rf = ResponseFuture(session, message, query, 1)
rf.send_request()
rf.session._pools.get.assert_called_once_with('ip1')
@ -279,7 +279,7 @@ class ResponseFutureTests(unittest.TestCase):
session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2']
session._pools.get.return_value.is_shutdown = True
rf = ResponseFuture(session, Mock(), Mock())
rf = ResponseFuture(session, Mock(), Mock(), 1)
rf.send_request()
self.assertRaises(NoHostAvailable, rf.result)
@ -354,7 +354,7 @@ class ResponseFutureTests(unittest.TestCase):
query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None)
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
rf = ResponseFuture(session, message, query)
rf = ResponseFuture(session, message, query, 1)
rf.send_request()
rf.add_errback(self.assertIsInstance, Exception)
@ -401,7 +401,7 @@ class ResponseFutureTests(unittest.TestCase):
query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None)
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
rf = ResponseFuture(session, message, query)
rf = ResponseFuture(session, message, query, 1)
rf.send_request()
callback = Mock()
@ -431,7 +431,7 @@ class ResponseFutureTests(unittest.TestCase):
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
# test errback
rf = ResponseFuture(session, message, query)
rf = ResponseFuture(session, message, query, 1)
rf.send_request()
rf.add_callbacks(
@ -443,7 +443,7 @@ class ResponseFutureTests(unittest.TestCase):
self.assertRaises(Exception, rf.result)
# test callback
rf = ResponseFuture(session, message, query)
rf = ResponseFuture(session, message, query, 1)
rf.send_request()
callback = Mock()