148
cassandra/concurrent.py
Normal file
148
cassandra/concurrent.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import sys
|
||||
|
||||
from itertools import count, cycle
|
||||
from threading import Event
|
||||
|
||||
|
||||
def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True):
    """
    Executes a sequence of (statement, parameters) tuples concurrently.  Each
    ``parameters`` item must be a sequence or :const:`None`.

    ``statements_and_parameters`` may be any iterable.  Iterators and
    generators are consumed into a list up front, because the number of
    statements must be known before execution starts.

    A sequence of ``(success, result_or_exc)`` tuples is returned in the same
    order that the statements were passed in.  If ``success`` is :const:`False`,
    there was an error executing the statement, and ``result_or_exc`` will be
    an :class:`Exception`.  If ``success`` is :const:`True`, ``result_or_exc``
    will be the query result.

    If `raise_on_first_error` is left as :const:`True`, execution will stop
    after the first failed statement and the corresponding exception will be
    raised.

    The `concurrency` parameter controls how many statements will be executed
    concurrently.  It is recommended that this be kept below the number of
    core connections per host times the number of connected hosts (see
    :meth:`.Cluster.set_core_connections_per_host`).  If that amount is exceeded,
    the event loop thread may attempt to block on new connection creation,
    substantially impacting throughput.

    Raises :class:`ValueError` if `concurrency` is not greater than zero.

    Example usage::

        select_statement = session.prepare("SELECT * FROM users WHERE id=?")

        statements_and_params = []
        for user_id in user_ids:
            statements_and_params.append(
                (select_statement, (user_id,)))

        results = execute_concurrent(
            session, statements_and_params, raise_on_first_error=False)

        for (success, result) in results:
            if not success:
                handle_error(result)  # result will be an Exception
            else:
                process_user(result[0])  # result will be a list of rows

    """
    if concurrency <= 0:
        raise ValueError("concurrency must be greater than 0")

    # Iterators and generators are always truthy and do not support len(),
    # so neither the emptiness check nor the result sizing below would work
    # on them.  Materialize anything that is not already sized.
    if not hasattr(statements_and_parameters, '__len__'):
        statements_and_parameters = list(statements_and_parameters)

    if not statements_and_parameters:
        return []

    event = Event()
    # A list collects the first failure when we should raise; None tells the
    # callbacks to record failures in `results` instead.
    first_error = [] if raise_on_first_error else None
    to_execute = len(statements_and_parameters)
    results = [None] * to_execute
    num_finished = count(1)
    # One shared iterator: every completed request pulls the next statement
    # from it, which keeps at most `concurrency` requests in flight.
    statements = enumerate(iter(statements_and_parameters))
    for i in range(min(concurrency, to_execute)):
        _execute_next(_sentinel, i, event, session, statements, results, num_finished, to_execute, first_error)

    event.wait()
    if first_error:
        error = first_error[0]
        if isinstance(error, tuple):
            # Client-side failures may be recorded as a sys.exc_info()
            # triple; raise the exception instance so its type and message
            # are preserved (raising the tuple itself is a TypeError on
            # Python 3).
            error = error[1]
        raise error
    else:
        return results
|
||||
|
||||
|
||||
def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs):
    """
    Like :meth:`~.execute_concurrent`, but takes a single statement and a
    sequence of parameters.  Each item in ``parameters`` should be a sequence
    or :const:`None`.

    Any additional positional or keyword arguments are passed through to
    :meth:`~.execute_concurrent` unchanged.

    Example usage::

        statement = session.prepare("INSERT INTO mytable (a, b) VALUES (1, ?)")
        parameters = [(x,) for x in range(1000)]
        execute_concurrent_with_args(session, statement, parameters)

    """
    # Build a concrete list rather than relying on zip(): on Python 3 zip()
    # is lazy, so the result would have no len() and would always be truthy.
    pairs = [(statement, params) for params in parameters]
    return execute_concurrent(session, pairs, *args, **kwargs)
|
||||
|
||||
|
||||
# Passed as `result` to _execute_next() for the initial submissions, where
# there is no previous result to record.
_sentinel = object()


def _submit_next(event, session, statements, results, num_finished, to_execute, first_error):
    """
    Pull the next statement from the shared ``statements`` iterator and submit
    it with ``session.execute_async()``, registering :func:`_execute_next` and
    :func:`_handle_error` as its callbacks.

    If submission itself raises (a client-side failure), the error is either
    stored in ``first_error`` (when it is a list) or recorded in ``results``.
    In the latter case the loop moves on to the following statement, so that
    the chain of submissions is not silently dropped; otherwise the caller
    could wait forever on ``event`` when no other request is in flight.
    """
    while True:
        try:
            (next_index, (statement, params)) = next(statements)
        except StopIteration:
            # Nothing left to submit; in-flight requests will set the event.
            return

        args = (next_index, event, session, statements, results, num_finished, to_execute, first_error)
        try:
            session.execute_async(statement, params).add_callbacks(
                callback=_execute_next, callback_args=args,
                errback=_handle_error, errback_args=args)
        except Exception as exc:
            if first_error is not None:
                # Store the exception itself so the waiting caller can
                # re-raise it with its original type and message.
                first_error.append(exc)
                event.set()
                return

            results[next_index] = (False, exc)
            if next(num_finished) >= to_execute:
                event.set()
                return
            # This statement failed before being sent; try the next one.
        else:
            return


def _handle_error(error, result_index, event, session, statements, results, num_finished, to_execute, first_error):
    """
    Errback for a failed request.

    When ``first_error`` is a list, the error is stored there and ``event`` is
    set so the waiting caller can raise it.  Otherwise the failure is recorded
    as ``(False, error)`` at ``result_index`` and, unless every statement has
    now finished, the next statement is submitted.
    """
    if first_error is not None:
        first_error.append(error)
        event.set()
        return

    results[result_index] = (False, error)
    if next(num_finished) >= to_execute:
        event.set()
        return

    _submit_next(event, session, statements, results, num_finished, to_execute, first_error)


def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
    """
    Callback for a successful request; also used to start the initial batch.

    Unless ``result`` is the module-level sentinel, it is recorded as
    ``(True, result)`` at ``result_index`` and the finished counter is
    advanced; ``event`` is set once every statement has finished.  Otherwise
    (or if work remains) the next statement is submitted.
    """
    if result is not _sentinel:
        results[result_index] = (True, result)
        if next(num_finished) >= to_execute:
            event.set()
            return

    _submit_next(event, session, statements, results, num_finished, to_execute, first_error)
|
113
tests/integration/standard/test_concurrent.py
Normal file
113
tests/integration/standard/test_concurrent.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from tests.integration import PROTOCOL_VERSION
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from itertools import cycle
|
||||
|
||||
from cassandra import InvalidRequest
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.concurrent import (execute_concurrent,
|
||||
execute_concurrent_with_args)
|
||||
from cassandra.policies import HostDistance
|
||||
from cassandra.query import tuple_factory
|
||||
|
||||
|
||||
class ClusterTests(unittest.TestCase):
    """
    Integration tests for ``execute_concurrent`` and
    ``execute_concurrent_with_args``.

    Every test talks to a live cluster and reads/writes ``test3rf.test``;
    that keyspace and table are assumed to be created by the integration
    test harness (not visible in this file).
    """

    def setUp(self):
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
        self.session.row_factory = tuple_factory

    def tearDown(self):
        # Release the connections opened in setUp so that they do not
        # accumulate across tests.
        self.cluster.shutdown()

    def test_execute_concurrent(self):
        # Sizes straddle the default concurrency (100) and multiples of it.
        for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
            # write
            statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
            parameters = [(i, i) for i in range(num_statements)]

            results = execute_concurrent(self.session, list(zip(statements, parameters)))
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, None)] * num_statements, results)

            # read
            statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", ))
            parameters = [(i, ) for i in range(num_statements)]

            results = execute_concurrent(self.session, list(zip(statements, parameters)))
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)

    def test_execute_concurrent_with_args(self):
        for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
            # write
            statement = "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"
            parameters = [(i, i) for i in range(num_statements)]

            results = execute_concurrent_with_args(self.session, statement, parameters)
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, None)] * num_statements, results)

            # read
            statement = "SELECT v FROM test3rf.test WHERE k=%s"
            parameters = [(i, ) for i in range(num_statements)]

            results = execute_concurrent_with_args(self.session, statement, parameters)
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)

    def test_first_failure(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
        parameters = [(i, i) for i in range(100)]

        # we'll get an error back from the server
        parameters[57] = ('efefef', 'awefawefawef')

        self.assertRaises(
            InvalidRequest,
            execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)

    def test_first_failure_client_side(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
        parameters = [(i, i) for i in range(100)]

        # the driver will raise an error when binding the params
        parameters[57] = 1

        self.assertRaises(
            TypeError,
            execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)

    def test_no_raise_on_first_failure(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
        parameters = [(i, i) for i in range(100)]

        # we'll get an error back from the server
        parameters[57] = ('efefef', 'awefawefawef')

        results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
        for i, (success, result) in enumerate(results):
            if i == 57:
                self.assertFalse(success)
                self.assertIsInstance(result, InvalidRequest)
            else:
                self.assertTrue(success)
                self.assertEqual(None, result)

    def test_no_raise_on_first_failure_client_side(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
        parameters = [(i, i) for i in range(100)]

        # the driver will raise an error when binding the params; use an
        # explicit non-sequence value rather than a leaked loop variable
        parameters[57] = 1

        results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
        for i, (success, result) in enumerate(results):
            if i == 57:
                self.assertFalse(success)
                self.assertIsInstance(result, TypeError)
            else:
                self.assertTrue(success)
                self.assertEqual(None, result)
|
Reference in New Issue
Block a user