Merge branch 'py3k' into 2.0
Conflicts: cassandra/cluster.py cassandra/encoder.py cassandra/marshal.py cassandra/pool.py setup.py tests/integration/long/test_large_data.py tests/integration/long/utils.py tests/integration/standard/test_metadata.py tests/integration/standard/test_prepared_statements.py tests/unit/io/test_asyncorereactor.py tests/unit/test_connection.py tests/unit/test_types.py
This commit is contained in:
@@ -4,11 +4,13 @@ env:
|
|||||||
- TOX_ENV=py26
|
- TOX_ENV=py26
|
||||||
- TOX_ENV=py27
|
- TOX_ENV=py27
|
||||||
- TOX_ENV=pypy
|
- TOX_ENV=pypy
|
||||||
|
- TOX_ENV=py33
|
||||||
|
|
||||||
before_install:
|
before_install:
|
||||||
- sudo apt-get update -y
|
- sudo apt-get update -y
|
||||||
- sudo apt-get install -y build-essential python-dev
|
- sudo apt-get install -y build-essential python-dev
|
||||||
- sudo apt-get install -y libev4 libev-dev
|
- sudo apt-get install -y libev4 libev-dev
|
||||||
|
|
||||||
install:
|
install:
|
||||||
- pip install tox
|
- pip install tox
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ try:
|
|||||||
from cassandra.io.libevreactor import LibevConnection
|
from cassandra.io.libevreactor import LibevConnection
|
||||||
have_libev = True
|
have_libev = True
|
||||||
supported_reactors.append(LibevConnection)
|
supported_reactors.append(LibevConnection)
|
||||||
except ImportError, exc:
|
except ImportError as exc:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
KEYSPACE = "testkeyspace"
|
KEYSPACE = "testkeyspace"
|
||||||
@@ -104,7 +104,7 @@ def benchmark(thread_class):
|
|||||||
""".format(table=TABLE))
|
""".format(table=TABLE))
|
||||||
values = ('key', 'a', 'b')
|
values = ('key', 'a', 'b')
|
||||||
|
|
||||||
per_thread = options.num_ops / options.threads
|
per_thread = options.num_ops // options.threads
|
||||||
threads = []
|
threads = []
|
||||||
|
|
||||||
log.debug("Beginning inserts...")
|
log.debug("Beginning inserts...")
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from itertools import count
|
|||||||
from threading import Event
|
from threading import Event
|
||||||
|
|
||||||
from base import benchmark, BenchmarkThread
|
from base import benchmark, BenchmarkThread
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,17 +38,17 @@ class Runner(BenchmarkThread):
|
|||||||
if previous_result is not sentinel:
|
if previous_result is not sentinel:
|
||||||
if isinstance(previous_result, BaseException):
|
if isinstance(previous_result, BaseException):
|
||||||
log.error("Error on insert: %r", previous_result)
|
log.error("Error on insert: %r", previous_result)
|
||||||
if self.num_finished.next() >= self.num_queries:
|
if next(self.num_finished) >= self.num_queries:
|
||||||
self.event.set()
|
self.event.set()
|
||||||
|
|
||||||
if self.num_started.next() <= self.num_queries:
|
if next(self.num_started) <= self.num_queries:
|
||||||
future = self.session.execute_async(self.query, self.values)
|
future = self.session.execute_async(self.query, self.values)
|
||||||
future.add_callbacks(self.insert_next, self.insert_next)
|
future.add_callbacks(self.insert_next, self.insert_next)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.start_profile()
|
self.start_profile()
|
||||||
|
|
||||||
for _ in xrange(min(120, self.num_queries)):
|
for _ in range(min(120, self.num_queries)):
|
||||||
self.insert_next()
|
self.insert_next()
|
||||||
|
|
||||||
self.event.wait()
|
self.event.wait()
|
||||||
|
|||||||
@@ -13,16 +13,16 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import Queue
|
|
||||||
|
|
||||||
from base import benchmark, BenchmarkThread
|
from base import benchmark, BenchmarkThread
|
||||||
|
from six.moves import queue
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Runner(BenchmarkThread):
|
class Runner(BenchmarkThread):
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
futures = Queue.Queue(maxsize=121)
|
futures = queue.Queue(maxsize=121)
|
||||||
|
|
||||||
self.start_profile()
|
self.start_profile()
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ class Runner(BenchmarkThread):
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
futures.get_nowait().result()
|
futures.get_nowait().result()
|
||||||
except Queue.Empty:
|
except queue.Empty:
|
||||||
break
|
break
|
||||||
|
|
||||||
future = self.session.execute_async(self.query, self.values)
|
future = self.session.execute_async(self.query, self.values)
|
||||||
@@ -41,7 +41,7 @@ class Runner(BenchmarkThread):
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
futures.get_nowait().result()
|
futures.get_nowait().result()
|
||||||
except Queue.Empty:
|
except queue.Empty:
|
||||||
break
|
break
|
||||||
|
|
||||||
self.finish_profile()
|
self.finish_profile()
|
||||||
|
|||||||
@@ -13,16 +13,16 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import Queue
|
|
||||||
|
|
||||||
from base import benchmark, BenchmarkThread
|
from base import benchmark, BenchmarkThread
|
||||||
|
from six.moves import queue
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Runner(BenchmarkThread):
|
class Runner(BenchmarkThread):
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
futures = Queue.Queue(maxsize=121)
|
futures = queue.Queue(maxsize=121)
|
||||||
|
|
||||||
self.start_profile()
|
self.start_profile()
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ class Runner(BenchmarkThread):
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
futures.get_nowait().result()
|
futures.get_nowait().result()
|
||||||
except Queue.Empty:
|
except queue.Empty:
|
||||||
break
|
break
|
||||||
|
|
||||||
self.finish_profile
|
self.finish_profile
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class Runner(BenchmarkThread):
|
|||||||
|
|
||||||
self.start_profile()
|
self.start_profile()
|
||||||
|
|
||||||
for i in range(self.num_queries):
|
for _ in range(self.num_queries):
|
||||||
future = self.session.execute_async(self.query, self.values)
|
future = self.session.execute_async(self.query, self.values)
|
||||||
futures.append(future)
|
futures.append(future)
|
||||||
|
|
||||||
|
|||||||
@@ -13,13 +13,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from base import benchmark, BenchmarkThread
|
from base import benchmark, BenchmarkThread
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
|
|
||||||
class Runner(BenchmarkThread):
|
class Runner(BenchmarkThread):
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.start_profile()
|
self.start_profile()
|
||||||
|
|
||||||
for i in xrange(self.num_queries):
|
for _ in range(self.num_queries):
|
||||||
self.session.execute(self.query, self.values)
|
self.session.execute(self.query, self.values)
|
||||||
|
|
||||||
self.finish_profile()
|
self.finish_profile()
|
||||||
|
|||||||
@@ -26,7 +26,11 @@ import socket
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from threading import Lock, RLock, Thread, Event
|
from threading import Lock, RLock, Thread, Event
|
||||||
import Queue
|
|
||||||
|
import six
|
||||||
|
from six.moves import range
|
||||||
|
from six.moves import queue as Queue
|
||||||
|
|
||||||
import weakref
|
import weakref
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
try:
|
try:
|
||||||
@@ -696,7 +700,7 @@ class Cluster(object):
|
|||||||
|
|
||||||
host.set_down()
|
host.set_down()
|
||||||
|
|
||||||
log.warn("Host %s has been marked down", host)
|
log.warning("Host %s has been marked down", host)
|
||||||
|
|
||||||
self.load_balancing_policy.on_down(host)
|
self.load_balancing_policy.on_down(host)
|
||||||
self.control_connection.on_down(host)
|
self.control_connection.on_down(host)
|
||||||
@@ -742,7 +746,7 @@ class Cluster(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if not all(futures_results):
|
if not all(futures_results):
|
||||||
log.warn("Connection pool could not be created, not marking node %s up", host)
|
log.warning("Connection pool could not be created, not marking node %s up", host)
|
||||||
return
|
return
|
||||||
|
|
||||||
self._finalize_add(host)
|
self._finalize_add(host)
|
||||||
@@ -867,7 +871,7 @@ class Cluster(object):
|
|||||||
# prepare 10 statements at a time
|
# prepare 10 statements at a time
|
||||||
ks_statements = list(ks_statements)
|
ks_statements = list(ks_statements)
|
||||||
chunks = []
|
chunks = []
|
||||||
for i in xrange(0, len(ks_statements), 10):
|
for i in range(0, len(ks_statements), 10):
|
||||||
chunks.append(ks_statements[i:i + 10])
|
chunks.append(ks_statements[i:i + 10])
|
||||||
|
|
||||||
for ks_chunk in chunks:
|
for ks_chunk in chunks:
|
||||||
@@ -882,9 +886,9 @@ class Cluster(object):
|
|||||||
|
|
||||||
log.debug("Done preparing all known prepared statements against host %s", host)
|
log.debug("Done preparing all known prepared statements against host %s", host)
|
||||||
except OperationTimedOut as timeout:
|
except OperationTimedOut as timeout:
|
||||||
log.warn("Timed out trying to prepare all statements on host %s: %s", host, timeout)
|
log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout)
|
||||||
except (ConnectionException, socket.error) as exc:
|
except (ConnectionException, socket.error) as exc:
|
||||||
log.warn("Error trying to prepare all statements on host %s: %r", host, exc)
|
log.warning("Error trying to prepare all statements on host %s: %r", host, exc)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception("Error trying to prepare all statements on host %s", host)
|
log.exception("Error trying to prepare all statements on host %s", host)
|
||||||
finally:
|
finally:
|
||||||
@@ -1088,7 +1092,7 @@ class Session(object):
|
|||||||
|
|
||||||
prepared_statement = None
|
prepared_statement = None
|
||||||
|
|
||||||
if isinstance(query, basestring):
|
if isinstance(query, six.string_types):
|
||||||
query = SimpleStatement(query)
|
query = SimpleStatement(query)
|
||||||
elif isinstance(query, PreparedStatement):
|
elif isinstance(query, PreparedStatement):
|
||||||
query = query.bind(parameters)
|
query = query.bind(parameters)
|
||||||
@@ -1235,8 +1239,8 @@ class Session(object):
|
|||||||
self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
|
self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
|
||||||
return False
|
return False
|
||||||
except Exception as conn_exc:
|
except Exception as conn_exc:
|
||||||
log.warn("Failed to create connection pool for new host %s: %s",
|
log.warning("Failed to create connection pool for new host %s: %s",
|
||||||
host, conn_exc)
|
host, conn_exc)
|
||||||
# the host itself will still be marked down, so we need to pass
|
# the host itself will still be marked down, so we need to pass
|
||||||
# a special flag to make sure the reconnector is created
|
# a special flag to make sure the reconnector is created
|
||||||
self.cluster.signal_connection_failure(
|
self.cluster.signal_connection_failure(
|
||||||
@@ -1456,11 +1460,11 @@ class ControlConnection(object):
|
|||||||
return self._try_connect(host)
|
return self._try_connect(host)
|
||||||
except ConnectionException as exc:
|
except ConnectionException as exc:
|
||||||
errors[host.address] = exc
|
errors[host.address] = exc
|
||||||
log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
|
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
|
||||||
self._cluster.signal_connection_failure(host, exc, is_host_addition=False)
|
self._cluster.signal_connection_failure(host, exc, is_host_addition=False)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
errors[host.address] = exc
|
errors[host.address] = exc
|
||||||
log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
|
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
|
||||||
|
|
||||||
raise NoHostAvailable("Unable to connect to any servers", errors)
|
raise NoHostAvailable("Unable to connect to any servers", errors)
|
||||||
|
|
||||||
@@ -1948,7 +1952,7 @@ class _Scheduler(object):
|
|||||||
def _log_if_failed(self, future):
|
def _log_if_failed(self, future):
|
||||||
exc = future.exception()
|
exc = future.exception()
|
||||||
if exc:
|
if exc:
|
||||||
log.warn(
|
log.warning(
|
||||||
"An internally scheduled tasked failed with an unhandled exception:",
|
"An internally scheduled tasked failed with an unhandled exception:",
|
||||||
exc_info=exc)
|
exc_info=exc)
|
||||||
|
|
||||||
@@ -2170,8 +2174,8 @@ class ResponseFuture(object):
|
|||||||
if self._metrics is not None:
|
if self._metrics is not None:
|
||||||
self._metrics.on_other_error()
|
self._metrics.on_other_error()
|
||||||
# need to retry against a different host here
|
# need to retry against a different host here
|
||||||
log.warn("Host %s is overloaded, retrying against a different "
|
log.warning("Host %s is overloaded, retrying against a different "
|
||||||
"host", self._current_host)
|
"host", self._current_host)
|
||||||
self._retry(reuse_connection=False, consistency_level=None)
|
self._retry(reuse_connection=False, consistency_level=None)
|
||||||
return
|
return
|
||||||
elif isinstance(response, IsBootstrappingErrorMessage):
|
elif isinstance(response, IsBootstrappingErrorMessage):
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from itertools import count, cycle
|
from itertools import count, cycle
|
||||||
|
from six.moves import xrange
|
||||||
from threading import Event
|
from threading import Event
|
||||||
|
|
||||||
|
|
||||||
@@ -105,7 +106,7 @@ def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs
|
|||||||
parameters = [(x,) for x in range(1000)]
|
parameters = [(x,) for x in range(1000)]
|
||||||
execute_concurrent_with_args(session, statement, parameters)
|
execute_concurrent_with_args(session, statement, parameters)
|
||||||
"""
|
"""
|
||||||
return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs)
|
return execute_concurrent(session, list(zip(cycle((statement,)), parameters)), *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
_sentinel = object()
|
_sentinel = object()
|
||||||
@@ -118,12 +119,12 @@ def _handle_error(error, result_index, event, session, statements, results, num_
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
results[result_index] = (False, error)
|
results[result_index] = (False, error)
|
||||||
if num_finished.next() >= to_execute:
|
if next(num_finished) >= to_execute:
|
||||||
event.set()
|
event.set()
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
(next_index, (statement, params)) = statements.next()
|
(next_index, (statement, params)) = next(statements)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -139,7 +140,7 @@ def _handle_error(error, result_index, event, session, statements, results, num_
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
results[next_index] = (False, exc)
|
results[next_index] = (False, exc)
|
||||||
if num_finished.next() >= to_execute:
|
if next(num_finished) >= to_execute:
|
||||||
event.set()
|
event.set()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -147,13 +148,13 @@ def _handle_error(error, result_index, event, session, statements, results, num_
|
|||||||
def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
|
def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
|
||||||
if result is not _sentinel:
|
if result is not _sentinel:
|
||||||
results[result_index] = (True, result)
|
results[result_index] = (True, result)
|
||||||
finished = num_finished.next()
|
finished = next(num_finished)
|
||||||
if finished >= to_execute:
|
if finished >= to_execute:
|
||||||
event.set()
|
event.set()
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
(next_index, (statement, params)) = statements.next()
|
(next_index, (statement, params)) = next(statements)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -169,6 +170,6 @@ def _execute_next(result, result_index, event, session, statements, results, num
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
results[next_index] = (False, exc)
|
results[next_index] = (False, exc)
|
||||||
if num_finished.next() >= to_execute:
|
if next(num_finished) >= to_execute:
|
||||||
event.set()
|
event.set()
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -19,19 +19,21 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
from threading import Event, RLock
|
from threading import Event, RLock
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
|
|
||||||
if 'gevent.monkey' in sys.modules:
|
if 'gevent.monkey' in sys.modules:
|
||||||
from gevent.queue import Queue, Empty
|
from gevent.queue import Queue, Empty
|
||||||
else:
|
else:
|
||||||
from Queue import Queue, Empty # noqa
|
from six.moves.queue import Queue, Empty # noqa
|
||||||
|
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
|
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
|
||||||
from cassandra.marshal import int8_unpack, int32_pack
|
from cassandra.marshal import int32_pack, header_unpack
|
||||||
from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
|
from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
|
||||||
StartupMessage, ErrorMessage, CredentialsMessage,
|
StartupMessage, ErrorMessage, CredentialsMessage,
|
||||||
QueryMessage, ResultMessage, decode_response,
|
QueryMessage, ResultMessage, decode_response,
|
||||||
InvalidRequestException, SupportedMessage)
|
InvalidRequestException, SupportedMessage)
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -123,7 +125,6 @@ def defunct_on_error(f):
|
|||||||
return f(self, *args, **kwargs)
|
return f(self, *args, **kwargs)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self.defunct(exc)
|
self.defunct(exc)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
@@ -181,12 +182,8 @@ class Connection(object):
|
|||||||
return
|
return
|
||||||
self.is_defunct = True
|
self.is_defunct = True
|
||||||
|
|
||||||
trace = traceback.format_exc(exc)
|
log.debug("Defuncting connection (%s) to %s:",
|
||||||
if trace != "None":
|
id(self), self.host, exc_info=exc)
|
||||||
log.debug("Defuncting connection (%s) to %s: %s\n%s",
|
|
||||||
id(self), self.host, exc, traceback.format_exc(exc))
|
|
||||||
else:
|
|
||||||
log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)
|
|
||||||
|
|
||||||
self.last_error = exc
|
self.last_error = exc
|
||||||
self.close()
|
self.close()
|
||||||
@@ -203,9 +200,9 @@ class Connection(object):
|
|||||||
try:
|
try:
|
||||||
cb(new_exc)
|
cb(new_exc)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.warn("Ignoring unhandled exception while erroring callbacks for a "
|
log.warning("Ignoring unhandled exception while erroring callbacks for a "
|
||||||
"failed connection (%s) to host %s:",
|
"failed connection (%s) to host %s:",
|
||||||
id(self), self.host, exc_info=True)
|
id(self), self.host, exc_info=True)
|
||||||
|
|
||||||
def handle_pushed(self, response):
|
def handle_pushed(self, response):
|
||||||
log.debug("Message pushed from server: %r", response)
|
log.debug("Message pushed from server: %r", response)
|
||||||
@@ -231,7 +228,7 @@ class Connection(object):
|
|||||||
request_id = self._id_queue.get()
|
request_id = self._id_queue.get()
|
||||||
|
|
||||||
self._callbacks[request_id] = cb
|
self._callbacks[request_id] = cb
|
||||||
self.push(msg.to_string(request_id, self.protocol_version, compression=self.compressor))
|
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
|
||||||
return request_id
|
return request_id
|
||||||
|
|
||||||
def wait_for_response(self, msg, timeout=None):
|
def wait_for_response(self, msg, timeout=None):
|
||||||
@@ -268,7 +265,7 @@ class Connection(object):
|
|||||||
return waiter.deliver(timeout)
|
return waiter.deliver(timeout)
|
||||||
except OperationTimedOut:
|
except OperationTimedOut:
|
||||||
raise
|
raise
|
||||||
except Exception, exc:
|
except Exception as exc:
|
||||||
self.defunct(exc)
|
self.defunct(exc)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -284,7 +281,7 @@ class Connection(object):
|
|||||||
|
|
||||||
@defunct_on_error
|
@defunct_on_error
|
||||||
def process_msg(self, msg, body_len):
|
def process_msg(self, msg, body_len):
|
||||||
version, flags, stream_id, opcode = map(int8_unpack, msg[:4])
|
version, flags, stream_id, opcode = header_unpack(msg[:4])
|
||||||
if stream_id < 0:
|
if stream_id < 0:
|
||||||
callback = None
|
callback = None
|
||||||
else:
|
else:
|
||||||
@@ -309,7 +306,7 @@ class Connection(object):
|
|||||||
if body_len > 0:
|
if body_len > 0:
|
||||||
body = msg[8:]
|
body = msg[8:]
|
||||||
elif body_len == 0:
|
elif body_len == 0:
|
||||||
body = ""
|
body = six.binary_type()
|
||||||
else:
|
else:
|
||||||
raise ProtocolError("Got negative body length: %r" % body_len)
|
raise ProtocolError("Got negative body length: %r" % body_len)
|
||||||
|
|
||||||
@@ -383,7 +380,7 @@ class Connection(object):
|
|||||||
locally_supported_compressions.keys(),
|
locally_supported_compressions.keys(),
|
||||||
remote_supported_compressions)
|
remote_supported_compressions)
|
||||||
else:
|
else:
|
||||||
compression_type = iter(overlap).next() # choose any
|
compression_type = next(iter(overlap)) # choose any
|
||||||
# set the decompressor here, but set the compressor only after
|
# set the decompressor here, but set the compressor only after
|
||||||
# a successful Ready message
|
# a successful Ready message
|
||||||
self._compressor, self.decompressor = \
|
self._compressor, self.decompressor = \
|
||||||
|
|||||||
@@ -36,10 +36,8 @@ from datetime import datetime
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
try:
|
import six
|
||||||
from cStringIO import StringIO
|
from six.moves import range
|
||||||
except ImportError:
|
|
||||||
from StringIO import StringIO # NOQA
|
|
||||||
|
|
||||||
from cassandra.marshal import (int8_pack, int8_unpack, uint16_pack, uint16_unpack,
|
from cassandra.marshal import (int8_pack, int8_unpack, uint16_pack, uint16_unpack,
|
||||||
int32_pack, int32_unpack, int64_pack, int64_unpack,
|
int32_pack, int32_unpack, int64_pack, int64_unpack,
|
||||||
@@ -49,7 +47,11 @@ from cassandra.util import OrderedDict
|
|||||||
|
|
||||||
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
|
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
|
||||||
|
|
||||||
_number_types = frozenset((int, long, float))
|
if six.PY3:
|
||||||
|
_number_types = frozenset((int, float))
|
||||||
|
long = int
|
||||||
|
else:
|
||||||
|
_number_types = frozenset((int, long, float))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from blist import sortedset
|
from blist import sortedset
|
||||||
@@ -69,7 +71,7 @@ def trim_if_startswith(s, prefix):
|
|||||||
|
|
||||||
|
|
||||||
def unix_time_from_uuid1(u):
|
def unix_time_from_uuid1(u):
|
||||||
return (u.get_time() - 0x01B21DD213814000) / 10000000.0
|
return (u.time - 0x01B21DD213814000) / 10000000.0
|
||||||
|
|
||||||
_casstypes = {}
|
_casstypes = {}
|
||||||
|
|
||||||
@@ -177,8 +179,8 @@ class EmptyValue(object):
|
|||||||
EMPTY = EmptyValue()
|
EMPTY = EmptyValue()
|
||||||
|
|
||||||
|
|
||||||
|
@six.add_metaclass(CassandraTypeType)
|
||||||
class _CassandraType(object):
|
class _CassandraType(object):
|
||||||
__metaclass__ = CassandraTypeType
|
|
||||||
subtypes = ()
|
subtypes = ()
|
||||||
num_subtypes = 0
|
num_subtypes = 0
|
||||||
empty_binary_ok = False
|
empty_binary_ok = False
|
||||||
@@ -199,9 +201,8 @@ class _CassandraType(object):
|
|||||||
def __init__(self, val):
|
def __init__(self, val):
|
||||||
self.val = self.validate(val)
|
self.val = self.validate(val)
|
||||||
|
|
||||||
def __str__(self):
|
def __repr__(self):
|
||||||
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
|
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
|
||||||
__repr__ = __str__
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate(val):
|
def validate(val):
|
||||||
@@ -221,7 +222,7 @@ class _CassandraType(object):
|
|||||||
"""
|
"""
|
||||||
if byts is None:
|
if byts is None:
|
||||||
return None
|
return None
|
||||||
elif byts == '' and not cls.empty_binary_ok:
|
elif len(byts) == 0 and not cls.empty_binary_ok:
|
||||||
return EMPTY if cls.support_empty_values else None
|
return EMPTY if cls.support_empty_values else None
|
||||||
return cls.deserialize(byts)
|
return cls.deserialize(byts)
|
||||||
|
|
||||||
@@ -232,7 +233,7 @@ class _CassandraType(object):
|
|||||||
more information. This method differs in that if None is passed in,
|
more information. This method differs in that if None is passed in,
|
||||||
the result is the empty string.
|
the result is the empty string.
|
||||||
"""
|
"""
|
||||||
return '' if val is None else cls.serialize(val)
|
return b'' if val is None else cls.serialize(val)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def deserialize(byts):
|
def deserialize(byts):
|
||||||
@@ -293,7 +294,8 @@ class _CassandraType(object):
|
|||||||
if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes:
|
if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes:
|
||||||
raise ValueError("%s types require %d subtypes (%d given)"
|
raise ValueError("%s types require %d subtypes (%d given)"
|
||||||
% (cls.typename, cls.num_subtypes, len(subtypes)))
|
% (cls.typename, cls.num_subtypes, len(subtypes)))
|
||||||
newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
|
# newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
|
||||||
|
newname = cls.cass_parameterized_type_with(subtypes)
|
||||||
return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname})
|
return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -324,10 +326,16 @@ class _UnrecognizedType(_CassandraType):
|
|||||||
num_subtypes = 'UNKNOWN'
|
num_subtypes = 'UNKNOWN'
|
||||||
|
|
||||||
|
|
||||||
def mkUnrecognizedType(casstypename):
|
if six.PY3:
|
||||||
return CassandraTypeType(casstypename.encode('utf8'),
|
def mkUnrecognizedType(casstypename):
|
||||||
(_UnrecognizedType,),
|
return CassandraTypeType(casstypename,
|
||||||
{'typename': "'%s'" % casstypename})
|
(_UnrecognizedType,),
|
||||||
|
{'typename': "'%s'" % casstypename})
|
||||||
|
else:
|
||||||
|
def mkUnrecognizedType(casstypename): # noqa
|
||||||
|
return CassandraTypeType(casstypename.encode('utf8'),
|
||||||
|
(_UnrecognizedType,),
|
||||||
|
{'typename': "'%s'" % casstypename})
|
||||||
|
|
||||||
|
|
||||||
class BytesType(_CassandraType):
|
class BytesType(_CassandraType):
|
||||||
@@ -336,11 +344,11 @@ class BytesType(_CassandraType):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate(val):
|
def validate(val):
|
||||||
return buffer(val)
|
return bytearray(val)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def serialize(val):
|
def serialize(val):
|
||||||
return str(val)
|
return six.binary_type(val)
|
||||||
|
|
||||||
|
|
||||||
class DecimalType(_CassandraType):
|
class DecimalType(_CassandraType):
|
||||||
@@ -401,9 +409,25 @@ class BooleanType(_CassandraType):
|
|||||||
return int8_pack(truth)
|
return int8_pack(truth)
|
||||||
|
|
||||||
|
|
||||||
class AsciiType(_CassandraType):
|
if six.PY2:
|
||||||
typename = 'ascii'
|
class AsciiType(_CassandraType):
|
||||||
empty_binary_ok = True
|
typename = 'ascii'
|
||||||
|
empty_binary_ok = True
|
||||||
|
else:
|
||||||
|
class AsciiType(_CassandraType):
|
||||||
|
typename = 'ascii'
|
||||||
|
empty_binary_ok = True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def deserialize(byts):
|
||||||
|
return byts.decode('ascii')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def serialize(var):
|
||||||
|
try:
|
||||||
|
return var.encode('ascii')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
return var
|
||||||
|
|
||||||
|
|
||||||
class FloatType(_CassandraType):
|
class FloatType(_CassandraType):
|
||||||
@@ -496,7 +520,7 @@ class DateType(_CassandraType):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, date):
|
def validate(cls, date):
|
||||||
if isinstance(date, basestring):
|
if isinstance(date, six.string_types):
|
||||||
date = cls.interpret_datestring(date)
|
date = cls.interpret_datestring(date)
|
||||||
return date
|
return date
|
||||||
|
|
||||||
@@ -628,7 +652,7 @@ class _SimpleParameterizedType(_ParameterizedType):
|
|||||||
numelements = uint16_unpack(byts[:2])
|
numelements = uint16_unpack(byts[:2])
|
||||||
p = 2
|
p = 2
|
||||||
result = []
|
result = []
|
||||||
for n in xrange(numelements):
|
for _ in range(numelements):
|
||||||
itemlen = uint16_unpack(byts[p:p + 2])
|
itemlen = uint16_unpack(byts[p:p + 2])
|
||||||
p += 2
|
p += 2
|
||||||
item = byts[p:p + itemlen]
|
item = byts[p:p + itemlen]
|
||||||
@@ -638,11 +662,11 @@ class _SimpleParameterizedType(_ParameterizedType):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def serialize_safe(cls, items):
|
def serialize_safe(cls, items):
|
||||||
if isinstance(items, basestring):
|
if isinstance(items, six.string_types):
|
||||||
raise TypeError("Received a string for a type that expects a sequence")
|
raise TypeError("Received a string for a type that expects a sequence")
|
||||||
|
|
||||||
subtype, = cls.subtypes
|
subtype, = cls.subtypes
|
||||||
buf = StringIO()
|
buf = six.BytesIO()
|
||||||
buf.write(uint16_pack(len(items)))
|
buf.write(uint16_pack(len(items)))
|
||||||
for item in items:
|
for item in items:
|
||||||
itembytes = subtype.to_binary(item)
|
itembytes = subtype.to_binary(item)
|
||||||
@@ -670,7 +694,7 @@ class MapType(_ParameterizedType):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, val):
|
def validate(cls, val):
|
||||||
subkeytype, subvaltype = cls.subtypes
|
subkeytype, subvaltype = cls.subtypes
|
||||||
return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in val.iteritems())
|
return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in six.iteritems(val))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize_safe(cls, byts):
|
def deserialize_safe(cls, byts):
|
||||||
@@ -678,7 +702,7 @@ class MapType(_ParameterizedType):
|
|||||||
numelements = uint16_unpack(byts[:2])
|
numelements = uint16_unpack(byts[:2])
|
||||||
p = 2
|
p = 2
|
||||||
themap = OrderedDict()
|
themap = OrderedDict()
|
||||||
for n in xrange(numelements):
|
for _ in range(numelements):
|
||||||
key_len = uint16_unpack(byts[p:p + 2])
|
key_len = uint16_unpack(byts[p:p + 2])
|
||||||
p += 2
|
p += 2
|
||||||
keybytes = byts[p:p + key_len]
|
keybytes = byts[p:p + key_len]
|
||||||
@@ -695,10 +719,10 @@ class MapType(_ParameterizedType):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def serialize_safe(cls, themap):
|
def serialize_safe(cls, themap):
|
||||||
subkeytype, subvaltype = cls.subtypes
|
subkeytype, subvaltype = cls.subtypes
|
||||||
buf = StringIO()
|
buf = six.BytesIO()
|
||||||
buf.write(uint16_pack(len(themap)))
|
buf.write(uint16_pack(len(themap)))
|
||||||
try:
|
try:
|
||||||
items = themap.iteritems()
|
items = six.iteritems(themap)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise TypeError("Got a non-map object for a map value")
|
raise TypeError("Got a non-map object for a map value")
|
||||||
for key, val in items:
|
for key, val in items:
|
||||||
@@ -747,7 +771,7 @@ class ReversedType(_ParameterizedType):
|
|||||||
|
|
||||||
|
|
||||||
def is_counter_type(t):
|
def is_counter_type(t):
|
||||||
if isinstance(t, basestring):
|
if isinstance(t, six.string_types):
|
||||||
t = lookup_casstype(t)
|
t = lookup_casstype(t)
|
||||||
return issubclass(t, CounterColumnType)
|
return issubclass(t, CounterColumnType)
|
||||||
|
|
||||||
|
|||||||
@@ -16,16 +16,14 @@ import logging
|
|||||||
import socket
|
import socket
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
try:
|
import six
|
||||||
from cStringIO import StringIO
|
from six.moves import range
|
||||||
except ImportError:
|
|
||||||
from StringIO import StringIO # ignore flake8 warning: # NOQA
|
|
||||||
|
|
||||||
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
|
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
|
||||||
AlreadyExists, InvalidRequest, Unauthorized,
|
AlreadyExists, InvalidRequest, Unauthorized,
|
||||||
UnsupportedOperation)
|
UnsupportedOperation)
|
||||||
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
|
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
|
||||||
int8_pack, int8_unpack)
|
int8_pack, int8_unpack, header_pack)
|
||||||
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
|
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
|
||||||
CounterColumnType, DateType, DecimalType,
|
CounterColumnType, DateType, DecimalType,
|
||||||
DoubleType, FloatType, Int32Type,
|
DoubleType, FloatType, Int32Type,
|
||||||
@@ -48,66 +46,75 @@ HEADER_DIRECTION_FROM_CLIENT = 0x00
|
|||||||
HEADER_DIRECTION_TO_CLIENT = 0x80
|
HEADER_DIRECTION_TO_CLIENT = 0x80
|
||||||
HEADER_DIRECTION_MASK = 0x80
|
HEADER_DIRECTION_MASK = 0x80
|
||||||
|
|
||||||
|
COMPRESSED_FLAG = 0x01
|
||||||
|
TRACING_FLAG = 0x02
|
||||||
|
|
||||||
_message_types_by_name = {}
|
_message_types_by_name = {}
|
||||||
_message_types_by_opcode = {}
|
_message_types_by_opcode = {}
|
||||||
|
|
||||||
|
|
||||||
class _register_msg_type(type):
|
class _RegisterMessageType(type):
|
||||||
def __init__(cls, name, bases, dct):
|
def __init__(cls, name, bases, dct):
|
||||||
if not name.startswith('_'):
|
if not name.startswith('_'):
|
||||||
_message_types_by_name[cls.name] = cls
|
_message_types_by_name[cls.name] = cls
|
||||||
_message_types_by_opcode[cls.opcode] = cls
|
_message_types_by_opcode[cls.opcode] = cls
|
||||||
|
|
||||||
|
|
||||||
|
@six.add_metaclass(_RegisterMessageType)
|
||||||
class _MessageType(object):
|
class _MessageType(object):
|
||||||
__metaclass__ = _register_msg_type
|
|
||||||
|
|
||||||
tracing = False
|
tracing = False
|
||||||
|
|
||||||
def to_string(self, stream_id, protocol_version, compression=None):
|
def to_binary(self, stream_id, protocol_version, compression=None):
|
||||||
body = StringIO()
|
body = six.BytesIO()
|
||||||
self.send_body(body, protocol_version)
|
self.send_body(body, protocol_version)
|
||||||
body = body.getvalue()
|
body = body.getvalue()
|
||||||
version = protocol_version | HEADER_DIRECTION_FROM_CLIENT
|
|
||||||
flags = 0
|
|
||||||
if compression is not None and len(body) > 0:
|
|
||||||
body = compression(body)
|
|
||||||
flags |= 0x01
|
|
||||||
if self.tracing:
|
|
||||||
flags |= 0x02
|
|
||||||
msglen = int32_pack(len(body))
|
|
||||||
msg_parts = map(int8_pack, (version, flags, stream_id, self.opcode)) + [msglen, body]
|
|
||||||
return ''.join(msg_parts)
|
|
||||||
|
|
||||||
def __str__(self):
|
flags = 0
|
||||||
paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in _get_params(self)]
|
if compression and len(body) > 0:
|
||||||
return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs))
|
body = compression(body)
|
||||||
__repr__ = __str__
|
flags |= COMPRESSED_FLAG
|
||||||
|
if self.tracing:
|
||||||
|
flags |= TRACING_FLAG
|
||||||
|
|
||||||
|
msg = six.BytesIO()
|
||||||
|
write_header(
|
||||||
|
msg,
|
||||||
|
protocol_version | HEADER_DIRECTION_FROM_CLIENT,
|
||||||
|
flags, stream_id, self.opcode, len(body)
|
||||||
|
)
|
||||||
|
msg.write(body)
|
||||||
|
|
||||||
|
return msg.getvalue()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
|
||||||
|
|
||||||
|
|
||||||
def _get_params(message_obj):
|
def _get_params(message_obj):
|
||||||
base_attrs = dir(_MessageType)
|
base_attrs = dir(_MessageType)
|
||||||
return [a for a in dir(message_obj)
|
return (
|
||||||
if a not in base_attrs and not a.startswith('_') and not callable(getattr(message_obj, a))]
|
(n, a) for n, a in message_obj.__dict__.items()
|
||||||
|
if n not in base_attrs and not n.startswith('_') and not callable(a)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def decode_response(stream_id, flags, opcode, body, decompressor=None):
|
def decode_response(stream_id, flags, opcode, body, decompressor=None):
|
||||||
if flags & 0x01:
|
if flags & COMPRESSED_FLAG:
|
||||||
if decompressor is None:
|
if decompressor is None:
|
||||||
raise Exception("No decompressor available for compressed frame!")
|
raise Exception("No de-compressor available for compressed frame!")
|
||||||
body = decompressor(body)
|
body = decompressor(body)
|
||||||
flags ^= 0x01
|
flags ^= COMPRESSED_FLAG
|
||||||
|
|
||||||
body = StringIO(body)
|
body = six.BytesIO(body)
|
||||||
if flags & 0x02:
|
if flags & TRACING_FLAG:
|
||||||
trace_id = UUID(bytes=body.read(16))
|
trace_id = UUID(bytes=body.read(16))
|
||||||
flags ^= 0x02
|
flags ^= TRACING_FLAG
|
||||||
else:
|
else:
|
||||||
trace_id = None
|
trace_id = None
|
||||||
|
|
||||||
if flags:
|
if flags:
|
||||||
log.warn("Unknown protocol flags set: %02x. May cause problems.", flags)
|
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
|
||||||
|
|
||||||
msg_class = _message_types_by_opcode[opcode]
|
msg_class = _message_types_by_opcode[opcode]
|
||||||
msg = msg_class.recv_body(body)
|
msg = msg_class.recv_body(body)
|
||||||
@@ -156,14 +163,14 @@ class ErrorMessage(_MessageType, Exception):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class ErrorMessageSubclass(_register_msg_type):
|
class ErrorMessageSubclass(_RegisterMessageType):
|
||||||
def __init__(cls, name, bases, dct):
|
def __init__(cls, name, bases, dct):
|
||||||
if cls.error_code is not None:
|
if cls.error_code is not None: # Server has an error code of 0.
|
||||||
error_classes[cls.error_code] = cls
|
error_classes[cls.error_code] = cls
|
||||||
|
|
||||||
|
|
||||||
|
@six.add_metaclass(ErrorMessageSubclass)
|
||||||
class ErrorMessageSub(ErrorMessage):
|
class ErrorMessageSub(ErrorMessage):
|
||||||
__metaclass__ = ErrorMessageSubclass
|
|
||||||
error_code = None
|
error_code = None
|
||||||
|
|
||||||
|
|
||||||
@@ -511,7 +518,7 @@ class ResultMessage(_MessageType):
|
|||||||
def recv_results_rows(cls, f):
|
def recv_results_rows(cls, f):
|
||||||
paging_state, column_metadata = cls.recv_results_metadata(f)
|
paging_state, column_metadata = cls.recv_results_metadata(f)
|
||||||
rowcount = read_int(f)
|
rowcount = read_int(f)
|
||||||
rows = [cls.recv_row(f, len(column_metadata)) for x in xrange(rowcount)]
|
rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
|
||||||
colnames = [c[2] for c in column_metadata]
|
colnames = [c[2] for c in column_metadata]
|
||||||
coltypes = [c[3] for c in column_metadata]
|
coltypes = [c[3] for c in column_metadata]
|
||||||
return (
|
return (
|
||||||
@@ -538,7 +545,7 @@ class ResultMessage(_MessageType):
|
|||||||
ksname = read_string(f)
|
ksname = read_string(f)
|
||||||
cfname = read_string(f)
|
cfname = read_string(f)
|
||||||
column_metadata = []
|
column_metadata = []
|
||||||
for x in xrange(colcount):
|
for _ in range(colcount):
|
||||||
if glob_tblspec:
|
if glob_tblspec:
|
||||||
colksname = ksname
|
colksname = ksname
|
||||||
colcfname = cfname
|
colcfname = cfname
|
||||||
@@ -580,7 +587,7 @@ class ResultMessage(_MessageType):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def recv_row(f, colcount):
|
def recv_row(f, colcount):
|
||||||
return [read_value(f) for x in xrange(colcount)]
|
return [read_value(f) for _ in range(colcount)]
|
||||||
|
|
||||||
|
|
||||||
class PrepareMessage(_MessageType):
|
class PrepareMessage(_MessageType):
|
||||||
@@ -729,6 +736,14 @@ class EventMessage(_MessageType):
|
|||||||
return dict(change_type=change_type, keyspace=keyspace, table=table)
|
return dict(change_type=change_type, keyspace=keyspace, table=table)
|
||||||
|
|
||||||
|
|
||||||
|
def write_header(f, version, flags, stream_id, opcode, length):
|
||||||
|
"""
|
||||||
|
Write a CQL protocol frame header.
|
||||||
|
"""
|
||||||
|
f.write(header_pack(version, flags, stream_id, opcode))
|
||||||
|
write_int(f, length)
|
||||||
|
|
||||||
|
|
||||||
def read_byte(f):
|
def read_byte(f):
|
||||||
return int8_unpack(f.read(1))
|
return int8_unpack(f.read(1))
|
||||||
|
|
||||||
@@ -774,7 +789,7 @@ def read_binary_string(f):
|
|||||||
|
|
||||||
|
|
||||||
def write_string(f, s):
|
def write_string(f, s):
|
||||||
if isinstance(s, unicode):
|
if isinstance(s, six.text_type):
|
||||||
s = s.encode('utf8')
|
s = s.encode('utf8')
|
||||||
write_short(f, len(s))
|
write_short(f, len(s))
|
||||||
f.write(s)
|
f.write(s)
|
||||||
@@ -791,7 +806,7 @@ def read_longstring(f):
|
|||||||
|
|
||||||
|
|
||||||
def write_longstring(f, s):
|
def write_longstring(f, s):
|
||||||
if isinstance(s, unicode):
|
if isinstance(s, six.text_type):
|
||||||
s = s.encode('utf8')
|
s = s.encode('utf8')
|
||||||
write_int(f, len(s))
|
write_int(f, len(s))
|
||||||
f.write(s)
|
f.write(s)
|
||||||
@@ -799,7 +814,7 @@ def write_longstring(f, s):
|
|||||||
|
|
||||||
def read_stringlist(f):
|
def read_stringlist(f):
|
||||||
numstrs = read_short(f)
|
numstrs = read_short(f)
|
||||||
return [read_string(f) for x in xrange(numstrs)]
|
return [read_string(f) for _ in range(numstrs)]
|
||||||
|
|
||||||
|
|
||||||
def write_stringlist(f, stringlist):
|
def write_stringlist(f, stringlist):
|
||||||
@@ -811,7 +826,7 @@ def write_stringlist(f, stringlist):
|
|||||||
def read_stringmap(f):
|
def read_stringmap(f):
|
||||||
numpairs = read_short(f)
|
numpairs = read_short(f)
|
||||||
strmap = {}
|
strmap = {}
|
||||||
for x in xrange(numpairs):
|
for _ in range(numpairs):
|
||||||
k = read_string(f)
|
k = read_string(f)
|
||||||
strmap[k] = read_string(f)
|
strmap[k] = read_string(f)
|
||||||
return strmap
|
return strmap
|
||||||
@@ -827,7 +842,7 @@ def write_stringmap(f, strmap):
|
|||||||
def read_stringmultimap(f):
|
def read_stringmultimap(f):
|
||||||
numkeys = read_short(f)
|
numkeys = read_short(f)
|
||||||
strmmap = {}
|
strmmap = {}
|
||||||
for x in xrange(numkeys):
|
for _ in range(numkeys):
|
||||||
k = read_string(f)
|
k = read_string(f)
|
||||||
strmmap[k] = read_stringlist(f)
|
strmmap[k] = read_stringlist(f)
|
||||||
return strmmap
|
return strmmap
|
||||||
|
|||||||
@@ -12,21 +12,34 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
import calendar
|
import calendar
|
||||||
import datetime
|
import datetime
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
import six
|
||||||
|
|
||||||
from cassandra.util import OrderedDict
|
from cassandra.util import OrderedDict
|
||||||
|
|
||||||
|
if six.PY3:
|
||||||
|
long = int
|
||||||
|
|
||||||
|
|
||||||
def cql_quote(term):
|
def cql_quote(term):
|
||||||
if isinstance(term, unicode):
|
# The ordering of this method is important for the result of this method to
|
||||||
return "'%s'" % term.encode('utf8').replace("'", "''")
|
# be a native str type (for both Python 2 and 3)
|
||||||
elif isinstance(term, (str, bool)):
|
|
||||||
|
# Handle quoting of native str and bool types
|
||||||
|
if isinstance(term, (str, bool)):
|
||||||
return "'%s'" % str(term).replace("'", "''")
|
return "'%s'" % str(term).replace("'", "''")
|
||||||
|
# This branch of the if statement will only be used by Python 2 to catch
|
||||||
|
# unicode strings, text_type is used to prevent type errors with Python 3.
|
||||||
|
elif isinstance(term, six.text_type):
|
||||||
|
return "'%s'" % term.encode('utf8').replace("'", "''")
|
||||||
else:
|
else:
|
||||||
return str(term)
|
return str(term)
|
||||||
|
|
||||||
@@ -43,13 +56,16 @@ def cql_encode_str(val):
|
|||||||
return cql_quote(val)
|
return cql_quote(val)
|
||||||
|
|
||||||
|
|
||||||
if sys.version_info >= (2, 7):
|
if six.PY3:
|
||||||
def cql_encode_bytes(val):
|
def cql_encode_bytes(val):
|
||||||
return '0x' + hexlify(val)
|
return (b'0x' + hexlify(val)).decode('utf-8')
|
||||||
|
elif sys.version_info >= (2, 7):
|
||||||
|
def cql_encode_bytes(val): # noqa
|
||||||
|
return b'0x' + hexlify(val)
|
||||||
else:
|
else:
|
||||||
# python 2.6 requires string or read-only buffer for hexlify
|
# python 2.6 requires string or read-only buffer for hexlify
|
||||||
def cql_encode_bytes(val): # noqa
|
def cql_encode_bytes(val): # noqa
|
||||||
return '0x' + hexlify(buffer(val))
|
return b'0x' + hexlify(buffer(val))
|
||||||
|
|
||||||
|
|
||||||
def cql_encode_object(val):
|
def cql_encode_object(val):
|
||||||
@@ -71,11 +87,10 @@ def cql_encode_sequence(val):
|
|||||||
|
|
||||||
|
|
||||||
def cql_encode_map_collection(val):
|
def cql_encode_map_collection(val):
|
||||||
return '{ %s }' % ' , '.join(
|
return '{ %s }' % ' , '.join('%s : %s' % (
|
||||||
'%s : %s' % (
|
cql_encode_all_types(k),
|
||||||
cql_encode_all_types(k),
|
cql_encode_all_types(v)
|
||||||
cql_encode_all_types(v))
|
) for k, v in six.iteritems(val))
|
||||||
for k, v in val.iteritems())
|
|
||||||
|
|
||||||
|
|
||||||
def cql_encode_list_collection(val):
|
def cql_encode_list_collection(val):
|
||||||
@@ -92,13 +107,9 @@ def cql_encode_all_types(val):
|
|||||||
|
|
||||||
cql_encoders = {
|
cql_encoders = {
|
||||||
float: cql_encode_object,
|
float: cql_encode_object,
|
||||||
buffer: cql_encode_bytes,
|
|
||||||
bytearray: cql_encode_bytes,
|
bytearray: cql_encode_bytes,
|
||||||
str: cql_encode_str,
|
str: cql_encode_str,
|
||||||
unicode: cql_encode_unicode,
|
|
||||||
types.NoneType: cql_encode_none,
|
|
||||||
int: cql_encode_object,
|
int: cql_encode_object,
|
||||||
long: cql_encode_object,
|
|
||||||
UUID: cql_encode_object,
|
UUID: cql_encode_object,
|
||||||
datetime.datetime: cql_encode_datetime,
|
datetime.datetime: cql_encode_datetime,
|
||||||
datetime.date: cql_encode_date,
|
datetime.date: cql_encode_date,
|
||||||
@@ -110,3 +121,17 @@ cql_encoders = {
|
|||||||
frozenset: cql_encode_set_collection,
|
frozenset: cql_encode_set_collection,
|
||||||
types.GeneratorType: cql_encode_list_collection
|
types.GeneratorType: cql_encode_list_collection
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if six.PY2:
|
||||||
|
cql_encoders.update({
|
||||||
|
unicode: cql_encode_unicode,
|
||||||
|
buffer: cql_encode_bytes,
|
||||||
|
long: cql_encode_object,
|
||||||
|
types.NoneType: cql_encode_none,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
cql_encoders.update({
|
||||||
|
memoryview: cql_encode_bytes,
|
||||||
|
bytes: cql_encode_bytes,
|
||||||
|
type(None): cql_encode_none,
|
||||||
|
})
|
||||||
|
|||||||
@@ -20,6 +20,10 @@ import os
|
|||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
from threading import Event, Lock, Thread
|
from threading import Event, Lock, Thread
|
||||||
|
|
||||||
|
from six import BytesIO
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode
|
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode
|
||||||
try:
|
try:
|
||||||
from weakref import WeakSet
|
from weakref import WeakSet
|
||||||
@@ -28,11 +32,6 @@ except ImportError:
|
|||||||
|
|
||||||
import asyncore
|
import asyncore
|
||||||
|
|
||||||
try:
|
|
||||||
from cStringIO import StringIO
|
|
||||||
except ImportError:
|
|
||||||
from StringIO import StringIO # ignore flake8 warning: # NOQA
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ssl
|
import ssl
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -141,7 +140,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
|||||||
asyncore.dispatcher.__init__(self)
|
asyncore.dispatcher.__init__(self)
|
||||||
|
|
||||||
self.connected_event = Event()
|
self.connected_event = Event()
|
||||||
self._iobuf = StringIO()
|
self._iobuf = BytesIO()
|
||||||
|
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
self.deque = deque()
|
self.deque = deque()
|
||||||
@@ -286,7 +285,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
|||||||
|
|
||||||
# leave leftover in current buffer
|
# leave leftover in current buffer
|
||||||
leftover = self._iobuf.read()
|
leftover = self._iobuf.read()
|
||||||
self._iobuf = StringIO()
|
self._iobuf = BytesIO()
|
||||||
self._iobuf.write(leftover)
|
self._iobuf.write(leftover)
|
||||||
|
|
||||||
self._total_reqd_bytes = 0
|
self._total_reqd_bytes = 0
|
||||||
@@ -302,7 +301,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
|||||||
sabs = self.out_buffer_size
|
sabs = self.out_buffer_size
|
||||||
if len(data) > sabs:
|
if len(data) > sabs:
|
||||||
chunks = []
|
chunks = []
|
||||||
for i in xrange(0, len(data), sabs):
|
for i in range(0, len(data), sabs):
|
||||||
chunks.append(data[i:i + sabs])
|
chunks.append(data[i:i + sabs])
|
||||||
else:
|
else:
|
||||||
chunks = [data]
|
chunks = [data]
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ import os
|
|||||||
import socket
|
import socket
|
||||||
from threading import Event, Lock, Thread
|
from threading import Event, Lock, Thread
|
||||||
|
|
||||||
|
from six import BytesIO
|
||||||
|
|
||||||
from cassandra import OperationTimedOut
|
from cassandra import OperationTimedOut
|
||||||
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
|
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
|
||||||
from cassandra.decoder import RegisterMessage
|
from cassandra.decoder import RegisterMessage
|
||||||
@@ -35,10 +37,6 @@ except ImportError:
|
|||||||
"for instructions on installing build dependencies and building "
|
"for instructions on installing build dependencies and building "
|
||||||
"the C extension.")
|
"the C extension.")
|
||||||
|
|
||||||
try:
|
|
||||||
from cStringIO import StringIO
|
|
||||||
except ImportError:
|
|
||||||
from StringIO import StringIO # ignore flake8 warning: # NOQA
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ssl
|
import ssl
|
||||||
@@ -197,7 +195,7 @@ class LibevConnection(Connection):
|
|||||||
Connection.__init__(self, *args, **kwargs)
|
Connection.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
self.connected_event = Event()
|
self.connected_event = Event()
|
||||||
self._iobuf = StringIO()
|
self._iobuf = BytesIO()
|
||||||
|
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
self.deque = deque()
|
self.deque = deque()
|
||||||
@@ -323,7 +321,7 @@ class LibevConnection(Connection):
|
|||||||
|
|
||||||
# leave leftover in current buffer
|
# leave leftover in current buffer
|
||||||
leftover = self._iobuf.read()
|
leftover = self._iobuf.read()
|
||||||
self._iobuf = StringIO()
|
self._iobuf = BytesIO()
|
||||||
self._iobuf.write(leftover)
|
self._iobuf.write(leftover)
|
||||||
|
|
||||||
self._total_reqd_bytes = 0
|
self._total_reqd_bytes = 0
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import six
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
|
||||||
@@ -37,12 +38,24 @@ uint8_pack, uint8_unpack = _make_packer('>B')
|
|||||||
float_pack, float_unpack = _make_packer('>f')
|
float_pack, float_unpack = _make_packer('>f')
|
||||||
double_pack, double_unpack = _make_packer('>d')
|
double_pack, double_unpack = _make_packer('>d')
|
||||||
|
|
||||||
|
# Special case for cassandra header
|
||||||
|
header_struct = struct.Struct('>BBbB')
|
||||||
|
header_pack = header_struct.pack
|
||||||
|
header_unpack = header_struct.unpack
|
||||||
|
|
||||||
def varint_unpack(term):
|
|
||||||
val = int(term.encode('hex'), 16)
|
if six.PY3:
|
||||||
if (ord(term[0]) & 128) != 0:
|
def varint_unpack(term):
|
||||||
val = val - (1 << (len(term) * 8))
|
val = int(''.join("%02x" % i for i in term), 16)
|
||||||
return val
|
if (term[0] & 128) != 0:
|
||||||
|
val -= 1 << (len(term) * 8)
|
||||||
|
return val
|
||||||
|
else:
|
||||||
|
def varint_unpack(term): # noqa
|
||||||
|
val = int(term.encode('hex'), 16)
|
||||||
|
if (ord(term[0]) & 128) != 0:
|
||||||
|
val = val - (1 << (len(term) * 8))
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
def bitlength(n):
|
def bitlength(n):
|
||||||
@@ -56,16 +69,16 @@ def bitlength(n):
|
|||||||
def varint_pack(big):
|
def varint_pack(big):
|
||||||
pos = True
|
pos = True
|
||||||
if big == 0:
|
if big == 0:
|
||||||
return '\x00'
|
return b'\x00'
|
||||||
if big < 0:
|
if big < 0:
|
||||||
bytelength = bitlength(abs(big) - 1) / 8 + 1
|
bytelength = bitlength(abs(big) - 1) // 8 + 1
|
||||||
big = (1 << bytelength * 8) + big
|
big = (1 << bytelength * 8) + big
|
||||||
pos = False
|
pos = False
|
||||||
revbytes = []
|
revbytes = bytearray()
|
||||||
while big > 0:
|
while big > 0:
|
||||||
revbytes.append(chr(big & 0xff))
|
revbytes.append(big & 0xff)
|
||||||
big >>= 8
|
big >>= 8
|
||||||
if pos and ord(revbytes[-1]) & 0x80:
|
if pos and revbytes[-1] & 0x80:
|
||||||
revbytes.append('\x00')
|
revbytes.append(0)
|
||||||
revbytes.reverse()
|
revbytes.reverse()
|
||||||
return ''.join(revbytes)
|
return six.binary_type(revbytes)
|
||||||
|
|||||||
@@ -21,11 +21,12 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
import weakref
|
import weakref
|
||||||
|
import six
|
||||||
|
|
||||||
murmur3 = None
|
murmur3 = None
|
||||||
try:
|
try:
|
||||||
from murmur3 import murmur3
|
from cassandra.murmur3 import murmur3
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
import cassandra.cqltypes as types
|
import cassandra.cqltypes as types
|
||||||
@@ -330,7 +331,7 @@ class Metadata(object):
|
|||||||
|
|
||||||
token_to_host_owner = {}
|
token_to_host_owner = {}
|
||||||
ring = []
|
ring = []
|
||||||
for host, token_strings in token_map.iteritems():
|
for host, token_strings in six.iteritems(token_map):
|
||||||
for token_string in token_strings:
|
for token_string in token_strings:
|
||||||
token = token_class(token_string)
|
token = token_class(token_string)
|
||||||
ring.append(token)
|
ring.append(token)
|
||||||
@@ -793,14 +794,18 @@ class TableMetadata(object):
|
|||||||
return list(sorted(ret))
|
return list(sorted(ret))
|
||||||
|
|
||||||
|
|
||||||
def protect_name(name):
|
if six.PY3:
|
||||||
if isinstance(name, unicode):
|
def protect_name(name):
|
||||||
name = name.encode('utf8')
|
return maybe_escape_name(name)
|
||||||
return maybe_escape_name(name)
|
else:
|
||||||
|
def protect_name(name):
|
||||||
|
if isinstance(name, six.text_type):
|
||||||
|
name = name.encode('utf8')
|
||||||
|
return maybe_escape_name(name)
|
||||||
|
|
||||||
|
|
||||||
def protect_names(names):
|
def protect_names(names):
|
||||||
return map(protect_name, names)
|
return [protect_name(n) for n in names]
|
||||||
|
|
||||||
|
|
||||||
def protect_value(value):
|
def protect_value(value):
|
||||||
@@ -1008,11 +1013,14 @@ class Token(object):
|
|||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.value == other.value
|
return self.value == other.value
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
return self.value < other.value
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.value)
|
return hash(self.value)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<%s: %r>" % (self.__class__.__name__, self.value)
|
return "<%s: %s>" % (self.__class__.__name__, self.value)
|
||||||
__str__ = __repr__
|
__str__ = __repr__
|
||||||
|
|
||||||
MIN_LONG = -(2 ** 63)
|
MIN_LONG = -(2 ** 63)
|
||||||
@@ -1031,7 +1039,7 @@ class Murmur3Token(Token):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def hash_fn(cls, key):
|
def hash_fn(cls, key):
|
||||||
if murmur3 is not None:
|
if murmur3 is not None:
|
||||||
h = murmur3(key)
|
h = int(murmur3(key))
|
||||||
return h if h != MIN_LONG else MAX_LONG
|
return h if h != MIN_LONG else MAX_LONG
|
||||||
else:
|
else:
|
||||||
raise NoMurmur3()
|
raise NoMurmur3()
|
||||||
@@ -1048,6 +1056,8 @@ class MD5Token(Token):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def hash_fn(cls, key):
|
def hash_fn(cls, key):
|
||||||
|
if isinstance(key, six.text_type):
|
||||||
|
key = key.encode('UTF-8')
|
||||||
return abs(varint_unpack(md5(key).digest()))
|
return abs(varint_unpack(md5(key).digest()))
|
||||||
|
|
||||||
def __init__(self, token):
|
def __init__(self, token):
|
||||||
@@ -1062,7 +1072,7 @@ class BytesToken(Token):
|
|||||||
|
|
||||||
def __init__(self, token_string):
|
def __init__(self, token_string):
|
||||||
""" `token_string` should be string representing the token. """
|
""" `token_string` should be string representing the token. """
|
||||||
if not isinstance(token_string, basestring):
|
if not isinstance(token_string, six.string_types):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Tokens for ByteOrderedPartitioner should be strings (got %s)"
|
"Tokens for ByteOrderedPartitioner should be strings (got %s)"
|
||||||
% (type(token_string),))
|
% (type(token_string),))
|
||||||
|
|||||||
@@ -169,6 +169,18 @@ uint64_t MurmurHash3_x64_128 (const void * key, const int len,
|
|||||||
return h1;
|
return h1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
struct module_state {
|
||||||
|
PyObject *error;
|
||||||
|
};
|
||||||
|
|
||||||
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m))
|
||||||
|
#else
|
||||||
|
#define GETSTATE(m) (&_state)
|
||||||
|
static struct module_state _state;
|
||||||
|
#endif
|
||||||
|
|
||||||
static PyObject *
|
static PyObject *
|
||||||
murmur3(PyObject *self, PyObject *args)
|
murmur3(PyObject *self, PyObject *args)
|
||||||
{
|
{
|
||||||
@@ -186,23 +198,63 @@ murmur3(PyObject *self, PyObject *args)
|
|||||||
}
|
}
|
||||||
|
|
||||||
static PyMethodDef murmur3_methods[] = {
|
static PyMethodDef murmur3_methods[] = {
|
||||||
{"murmur3", murmur3, METH_VARARGS,
|
{"murmur3", murmur3, METH_VARARGS, "Make an x64 murmur3 64-bit hash value"},
|
||||||
"Make an x64 murmur3 64-bit hash value"},
|
|
||||||
|
|
||||||
{NULL, NULL, 0, NULL}
|
{NULL, NULL, 0, NULL}
|
||||||
};
|
};
|
||||||
|
|
||||||
#if PY_MAJOR_VERSION <= 2
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
|
||||||
PyMODINIT_FUNC
|
static int murmur3_traverse(PyObject *m, visitproc visit, void *arg) {
|
||||||
initmurmur3(void)
|
Py_VISIT(GETSTATE(m)->error);
|
||||||
{
|
return 0;
|
||||||
(void) Py_InitModule("murmur3", murmur3_methods);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int murmur3_clear(PyObject *m) {
|
||||||
|
Py_CLEAR(GETSTATE(m)->error);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct PyModuleDef moduledef = {
|
||||||
|
PyModuleDef_HEAD_INIT,
|
||||||
|
"murmur3",
|
||||||
|
NULL,
|
||||||
|
sizeof(struct module_state),
|
||||||
|
murmur3_methods,
|
||||||
|
NULL,
|
||||||
|
murmur3_traverse,
|
||||||
|
murmur3_clear,
|
||||||
|
NULL
|
||||||
|
};
|
||||||
|
|
||||||
|
#define INITERROR return NULL
|
||||||
|
|
||||||
|
PyObject *
|
||||||
|
PyInit_murmur3(void)
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
#define INITERROR return
|
||||||
|
|
||||||
/* Python 3.x */
|
void
|
||||||
// TODO
|
initmurmur3(void)
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
{
|
||||||
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
PyObject *module = PyModule_Create(&moduledef);
|
||||||
|
#else
|
||||||
|
PyObject *module = Py_InitModule("murmur3", murmur3_methods);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (module == NULL)
|
||||||
|
INITERROR;
|
||||||
|
struct module_state *st = GETSTATE(module);
|
||||||
|
|
||||||
|
st->error = PyErr_NewException("murmur3.Error", NULL, NULL);
|
||||||
|
if (st->error == NULL) {
|
||||||
|
Py_DECREF(module);
|
||||||
|
INITERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
return module;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,9 +16,12 @@ from itertools import islice, cycle, groupby, repeat
|
|||||||
import logging
|
import logging
|
||||||
from random import randint
|
from random import randint
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
import six
|
||||||
|
|
||||||
from cassandra import ConsistencyLevel
|
from cassandra import ConsistencyLevel
|
||||||
|
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -263,7 +266,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
|
|||||||
for host in islice(cycle(local_live), pos, pos + len(local_live)):
|
for host in islice(cycle(local_live), pos, pos + len(local_live)):
|
||||||
yield host
|
yield host
|
||||||
|
|
||||||
for dc, current_dc_hosts in self._dc_live_hosts.iteritems():
|
for dc, current_dc_hosts in six.iteritems(self._dc_live_hosts):
|
||||||
if dc == self.local_dc:
|
if dc == self.local_dc:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -529,7 +532,7 @@ class ExponentialReconnectionPolicy(ReconnectionPolicy):
|
|||||||
self.max_delay = max_delay
|
self.max_delay = max_delay
|
||||||
|
|
||||||
def new_schedule(self):
|
def new_schedule(self):
|
||||||
return (min(self.base_delay * (2 ** i), self.max_delay) for i in xrange(64))
|
return (min(self.base_delay * (2 ** i), self.max_delay) for i in range(64))
|
||||||
|
|
||||||
|
|
||||||
class WriteType(object):
|
class WriteType(object):
|
||||||
|
|||||||
@@ -150,8 +150,14 @@ class Host(object):
|
|||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.address == other.address
|
return self.address == other.address
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.address)
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
return self.address < other.address
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.address
|
return str(self.address)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
dc = (" %s" % (self._datacenter,)) if self._datacenter else ""
|
dc = (" %s" % (self._datacenter,)) if self._datacenter else ""
|
||||||
@@ -178,7 +184,7 @@ class _ReconnectionHandler(object):
|
|||||||
log.debug("Reconnection handler was cancelled before starting")
|
log.debug("Reconnection handler was cancelled before starting")
|
||||||
return
|
return
|
||||||
|
|
||||||
first_delay = self.schedule.next()
|
first_delay = next(self.schedule)
|
||||||
self.scheduler.schedule(first_delay, self.run)
|
self.scheduler.schedule(first_delay, self.run)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
@@ -189,7 +195,7 @@ class _ReconnectionHandler(object):
|
|||||||
try:
|
try:
|
||||||
conn = self.try_reconnect()
|
conn = self.try_reconnect()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
next_delay = self.schedule.next()
|
next_delay = next(self.schedule)
|
||||||
if self.on_exception(exc, next_delay):
|
if self.on_exception(exc, next_delay):
|
||||||
self.scheduler.schedule(next_delay, self.run)
|
self.scheduler.schedule(next_delay, self.run)
|
||||||
else:
|
else:
|
||||||
@@ -260,8 +266,8 @@ class _HostReconnectionHandler(_ReconnectionHandler):
|
|||||||
if isinstance(exc, AuthenticationFailed):
|
if isinstance(exc, AuthenticationFailed):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
log.warn("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s",
|
log.warning("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s",
|
||||||
self.host, next_delay, exc)
|
self.host, next_delay, exc)
|
||||||
log.debug("Reconnection error details", exc_info=True)
|
log.debug("Reconnection error details", exc_info=True)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -371,8 +377,8 @@ class HostConnectionPool(object):
|
|||||||
def _create_new_connection(self):
|
def _create_new_connection(self):
|
||||||
try:
|
try:
|
||||||
self._add_conn_if_under_max()
|
self._add_conn_if_under_max()
|
||||||
except (ConnectionException, socket.error), exc:
|
except (ConnectionException, socket.error) as exc:
|
||||||
log.warn("Failed to create new connection to %s: %s", self.host, exc)
|
log.warning("Failed to create new connection to %s: %s", self.host, exc)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception("Unexpectedly failed to create new connection")
|
log.exception("Unexpectedly failed to create new connection")
|
||||||
finally:
|
finally:
|
||||||
@@ -404,7 +410,7 @@ class HostConnectionPool(object):
|
|||||||
self._signal_available_conn()
|
self._signal_available_conn()
|
||||||
return True
|
return True
|
||||||
except (ConnectionException, socket.error) as exc:
|
except (ConnectionException, socket.error) as exc:
|
||||||
log.warn("Failed to add new connection to pool for host %s: %s", self.host, exc)
|
log.warning("Failed to add new connection to pool for host %s: %s", self.host, exc)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.open_count -= 1
|
self.open_count -= 1
|
||||||
if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False):
|
if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False):
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from datetime import datetime, timedelta
|
|||||||
import re
|
import re
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
|
import six
|
||||||
|
|
||||||
from cassandra import ConsistencyLevel, OperationTimedOut
|
from cassandra import ConsistencyLevel, OperationTimedOut
|
||||||
from cassandra.cqltypes import unix_time_from_uuid1
|
from cassandra.cqltypes import unix_time_from_uuid1
|
||||||
@@ -113,8 +114,8 @@ class Statement(object):
|
|||||||
|
|
||||||
def _set_routing_key(self, key):
|
def _set_routing_key(self, key):
|
||||||
if isinstance(key, (list, tuple)):
|
if isinstance(key, (list, tuple)):
|
||||||
self._routing_key = "".join(struct.pack("HsB", len(component), component, 0)
|
self._routing_key = b"".join(struct.pack("HsB", len(component), component, 0)
|
||||||
for component in key)
|
for component in key)
|
||||||
else:
|
else:
|
||||||
self._routing_key = key
|
self._routing_key = key
|
||||||
|
|
||||||
@@ -408,7 +409,7 @@ class BoundStatement(Statement):
|
|||||||
val = self.values[statement_index]
|
val = self.values[statement_index]
|
||||||
components.append(struct.pack("HsB", len(val), val, 0))
|
components.append(struct.pack("HsB", len(val), val, 0))
|
||||||
|
|
||||||
self._routing_key = "".join(components)
|
self._routing_key = b"".join(components)
|
||||||
|
|
||||||
return self._routing_key
|
return self._routing_key
|
||||||
|
|
||||||
@@ -473,7 +474,7 @@ class BatchStatement(Statement):
|
|||||||
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level)
|
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level)
|
||||||
|
|
||||||
def add(self, statement, parameters=None):
|
def add(self, statement, parameters=None):
|
||||||
if isinstance(statement, basestring):
|
if isinstance(statement, six.string_types):
|
||||||
if parameters:
|
if parameters:
|
||||||
statement = bind_params(statement, parameters)
|
statement = bind_params(statement, parameters)
|
||||||
self._statements_and_parameters.append((False, statement, ()))
|
self._statements_and_parameters.append((False, statement, ()))
|
||||||
@@ -532,11 +533,9 @@ class ValueSequence(object):
|
|||||||
|
|
||||||
def bind_params(query, params):
|
def bind_params(query, params):
|
||||||
if isinstance(params, dict):
|
if isinstance(params, dict):
|
||||||
return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v))
|
return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
|
||||||
for k, v in params.iteritems())
|
|
||||||
else:
|
else:
|
||||||
return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v)
|
return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v) for v in params)
|
||||||
for v in params)
|
|
||||||
|
|
||||||
|
|
||||||
class TraceUnavailable(Exception):
|
class TraceUnavailable(Exception):
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
from __future__ import with_statement
|
from __future__ import with_statement
|
||||||
|
|
||||||
from UserDict import DictMixin
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -28,6 +26,7 @@ except ImportError:
|
|||||||
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||||
# OTHER DEALINGS IN THE SOFTWARE.
|
# OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
from UserDict import DictMixin
|
||||||
|
|
||||||
class OrderedDict(dict, DictMixin): # noqa
|
class OrderedDict(dict, DictMixin): # noqa
|
||||||
""" A dictionary which maintains the insertion order of keys. """
|
""" A dictionary which maintains the insertion order of keys. """
|
||||||
@@ -80,9 +79,9 @@ except ImportError:
|
|||||||
if not self:
|
if not self:
|
||||||
raise KeyError('dictionary is empty')
|
raise KeyError('dictionary is empty')
|
||||||
if last:
|
if last:
|
||||||
key = reversed(self).next()
|
key = next(reversed(self))
|
||||||
else:
|
else:
|
||||||
key = iter(self).next()
|
key = next(iter(self))
|
||||||
value = self.pop(key)
|
value = self.pop(key)
|
||||||
return key, value
|
return key, value
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
blist
|
blist
|
||||||
futures
|
futures
|
||||||
scales
|
scales
|
||||||
|
six >=1.6
|
||||||
|
|||||||
26
setup.py
26
setup.py
@@ -12,16 +12,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import ez_setup
|
import ez_setup
|
||||||
ez_setup.use_setuptools()
|
ez_setup.use_setuptools()
|
||||||
|
|
||||||
run_gevent_nosetests = False
|
|
||||||
if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
|
if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
|
||||||
from gevent.monkey import patch_all
|
from gevent.monkey import patch_all
|
||||||
patch_all()
|
patch_all()
|
||||||
run_gevent_nosetests = True
|
|
||||||
|
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
from distutils.command.build_ext import build_ext
|
from distutils.command.build_ext import build_ext
|
||||||
@@ -48,10 +47,11 @@ with open("README.rst") as f:
|
|||||||
long_description = f.read()
|
long_description = f.read()
|
||||||
|
|
||||||
|
|
||||||
gevent_nosetests = None
|
try:
|
||||||
if run_gevent_nosetests:
|
|
||||||
from nose.commands import nosetests
|
from nose.commands import nosetests
|
||||||
|
except ImportError:
|
||||||
|
gevent_nosetests = None
|
||||||
|
else:
|
||||||
class gevent_nosetests(nosetests):
|
class gevent_nosetests(nosetests):
|
||||||
description = "run nosetests with gevent monkey patching"
|
description = "run nosetests with gevent monkey patching"
|
||||||
|
|
||||||
@@ -92,11 +92,11 @@ class DocCommand(Command):
|
|||||||
except subprocess.CalledProcessError as exc:
|
except subprocess.CalledProcessError as exc:
|
||||||
raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output))
|
raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output))
|
||||||
else:
|
else:
|
||||||
print output
|
print(output)
|
||||||
|
|
||||||
print ""
|
print("")
|
||||||
print "Documentation step '%s' performed, results here:" % mode
|
print("Documentation step '%s' performed, results here:" % mode)
|
||||||
print " %s/" % path
|
print(" %s/" % path)
|
||||||
|
|
||||||
|
|
||||||
class BuildFailed(Exception):
|
class BuildFailed(Exception):
|
||||||
@@ -181,7 +181,7 @@ def run_setup(extensions):
|
|||||||
kw['cmdclass']['build_ext'] = build_extensions
|
kw['cmdclass']['build_ext'] = build_extensions
|
||||||
kw['ext_modules'] = extensions
|
kw['ext_modules'] = extensions
|
||||||
|
|
||||||
dependencies = ['futures', 'scales', 'blist']
|
dependencies = ['futures', 'scales >=1.0.5', 'blist', 'six >=1.6']
|
||||||
if platform.python_implementation() != "CPython":
|
if platform.python_implementation() != "CPython":
|
||||||
dependencies.remove('blist')
|
dependencies.remove('blist')
|
||||||
|
|
||||||
@@ -196,7 +196,7 @@ def run_setup(extensions):
|
|||||||
packages=['cassandra', 'cassandra.io'],
|
packages=['cassandra', 'cassandra.io'],
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
install_requires=dependencies,
|
install_requires=dependencies,
|
||||||
tests_require=['nose', 'mock', 'ccm', 'unittest2', 'PyYAML', 'pytz'],
|
tests_require=['nose', 'mock', 'PyYAML', 'pytz'],
|
||||||
classifiers=[
|
classifiers=[
|
||||||
'Development Status :: 5 - Production/Stable',
|
'Development Status :: 5 - Production/Stable',
|
||||||
'Intended Audience :: Developers',
|
'Intended Audience :: Developers',
|
||||||
@@ -207,6 +207,10 @@ def run_setup(extensions):
|
|||||||
'Programming Language :: Python :: 2',
|
'Programming Language :: Python :: 2',
|
||||||
'Programming Language :: Python :: 2.6',
|
'Programming Language :: Python :: 2.6',
|
||||||
'Programming Language :: Python :: 2.7',
|
'Programming Language :: Python :: 2.7',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
'Programming Language :: Python :: 3.3',
|
||||||
|
'Programming Language :: Python :: Implementation :: CPython',
|
||||||
|
'Programming Language :: Python :: Implementation :: PyPy',
|
||||||
'Topic :: Software Development :: Libraries :: Python Modules'
|
'Topic :: Software Development :: Libraries :: Python Modules'
|
||||||
],
|
],
|
||||||
**kw)
|
**kw)
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ except ImportError:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from six import print_
|
||||||
from threading import Event
|
from threading import Event
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
@@ -92,7 +94,7 @@ def get_node(node_id):
|
|||||||
|
|
||||||
|
|
||||||
def setup_package():
|
def setup_package():
|
||||||
print 'Using Cassandra version: %s' % CASSANDRA_VERSION
|
print_('Using Cassandra version: %s' % CASSANDRA_VERSION)
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
cluster = CCMCluster.load(path, CLUSTER_NAME)
|
cluster = CCMCluster.load(path, CLUSTER_NAME)
|
||||||
|
|||||||
@@ -12,7 +12,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import Queue
|
try:
|
||||||
|
from Queue import Queue, Empty
|
||||||
|
except ImportError:
|
||||||
|
from queue import Queue, Empty # noqa
|
||||||
|
|
||||||
from struct import pack
|
from struct import pack
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -31,13 +35,12 @@ def create_column_name(i):
|
|||||||
column_name = ''
|
column_name = ''
|
||||||
while True:
|
while True:
|
||||||
column_name += letters[i % 10]
|
column_name += letters[i % 10]
|
||||||
i /= 10
|
i = i // 10
|
||||||
if not i:
|
if not i:
|
||||||
break
|
break
|
||||||
|
|
||||||
if column_name == 'if':
|
if column_name == 'if':
|
||||||
column_name = 'special_case'
|
column_name = 'special_case'
|
||||||
|
|
||||||
return column_name
|
return column_name
|
||||||
|
|
||||||
|
|
||||||
@@ -56,15 +59,15 @@ class LargeDataTests(unittest.TestCase):
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
def batch_futures(self, session, statement_generator):
|
def batch_futures(self, session, statement_generator):
|
||||||
concurrency = 50
|
concurrency = 10
|
||||||
futures = Queue.Queue(maxsize=concurrency)
|
futures = Queue(maxsize=concurrency)
|
||||||
for i, statement in enumerate(statement_generator):
|
for i, statement in enumerate(statement_generator):
|
||||||
if i > 0 and i % (concurrency - 1) == 0:
|
if i > 0 and i % (concurrency - 1) == 0:
|
||||||
# clear the existing queue
|
# clear the existing queue
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
futures.get_nowait().result()
|
futures.get_nowait().result()
|
||||||
except Queue.Empty:
|
except Empty:
|
||||||
break
|
break
|
||||||
|
|
||||||
future = session.execute_async(statement)
|
future = session.execute_async(statement)
|
||||||
@@ -73,7 +76,7 @@ class LargeDataTests(unittest.TestCase):
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
futures.get_nowait().result()
|
futures.get_nowait().result()
|
||||||
except Queue.Empty:
|
except Empty:
|
||||||
break
|
break
|
||||||
|
|
||||||
def test_wide_rows(self):
|
def test_wide_rows(self):
|
||||||
@@ -81,15 +84,13 @@ class LargeDataTests(unittest.TestCase):
|
|||||||
session = self.make_session_and_keyspace()
|
session = self.make_session_and_keyspace()
|
||||||
session.execute('CREATE TABLE %s (k INT, i INT, PRIMARY KEY(k, i))' % table)
|
session.execute('CREATE TABLE %s (k INT, i INT, PRIMARY KEY(k, i))' % table)
|
||||||
|
|
||||||
|
prepared = session.prepare('INSERT INTO %s (k, i) VALUES (0, ?)' % (table, ))
|
||||||
|
|
||||||
# Write via async futures
|
# Write via async futures
|
||||||
self.batch_futures(
|
self.batch_futures(session, (prepared.bind((i, )) for i in range(100000)))
|
||||||
session,
|
|
||||||
(SimpleStatement('INSERT INTO %s (k, i) VALUES (0, %s)' % (table, i),
|
|
||||||
consistency_level=ConsistencyLevel.QUORUM)
|
|
||||||
for i in range(100000)))
|
|
||||||
|
|
||||||
# Read
|
# Read
|
||||||
results = session.execute('SELECT i FROM %s WHERE k=%s' % (table, 0))
|
results = session.execute('SELECT i FROM %s WHERE k=0' % (table, ))
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
for i, row in enumerate(results):
|
for i, row in enumerate(results):
|
||||||
@@ -120,18 +121,13 @@ class LargeDataTests(unittest.TestCase):
|
|||||||
session = self.make_session_and_keyspace()
|
session = self.make_session_and_keyspace()
|
||||||
session.execute('CREATE TABLE %s (k INT, i INT, v BLOB, PRIMARY KEY(k, i))' % table)
|
session.execute('CREATE TABLE %s (k INT, i INT, v BLOB, PRIMARY KEY(k, i))' % table)
|
||||||
|
|
||||||
# Build small ByteBuffer sample
|
prepared = session.prepare('INSERT INTO %s (k, i, v) VALUES (0, ?, 0xCAFE)' % (table, ))
|
||||||
bb = '0xCAFE'
|
|
||||||
|
|
||||||
# Write
|
# Write
|
||||||
self.batch_futures(
|
self.batch_futures(session, (prepared.bind((i, )) for i in range(100000)))
|
||||||
session,
|
|
||||||
(SimpleStatement('INSERT INTO %s (k, i, v) VALUES (0, %s, %s)' % (table, i, str(bb)),
|
|
||||||
consistency_level=ConsistencyLevel.QUORUM)
|
|
||||||
for i in range(100000)))
|
|
||||||
|
|
||||||
# Read
|
# Read
|
||||||
results = session.execute('SELECT i, v FROM %s WHERE k=%s' % (table, 0))
|
results = session.execute('SELECT i, v FROM %s WHERE k=0' % (table, ))
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
bb = pack('>H', 0xCAFE)
|
bb = pack('>H', 0xCAFE)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -125,7 +126,7 @@ def bootstrap(node, data_center=None, token=None):
|
|||||||
|
|
||||||
|
|
||||||
def ring(node):
|
def ring(node):
|
||||||
print 'From node%s:' % node
|
print('From node%s:' % node)
|
||||||
get_node(node).nodetool('ring')
|
get_node(node).nodetool('ring')
|
||||||
|
|
||||||
|
|
||||||
@@ -154,3 +155,5 @@ def wait_for_down(cluster, node, wait=True):
|
|||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
log.debug("Done waiting for node %s to be down", node)
|
log.debug("Done waiting for node %s to be down", node)
|
||||||
return
|
return
|
||||||
|
else:
|
||||||
|
log.debug("Host is still marked up, waiting")
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class ClusterTests(unittest.TestCase):
|
|||||||
CREATE KEYSPACE clustertests
|
CREATE KEYSPACE clustertests
|
||||||
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
|
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
|
||||||
""")
|
""")
|
||||||
self.assertEquals(None, result)
|
self.assertEqual(None, result)
|
||||||
|
|
||||||
result = session.execute(
|
result = session.execute(
|
||||||
"""
|
"""
|
||||||
@@ -51,16 +51,16 @@ class ClusterTests(unittest.TestCase):
|
|||||||
PRIMARY KEY (a, b)
|
PRIMARY KEY (a, b)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
self.assertEquals(None, result)
|
self.assertEqual(None, result)
|
||||||
|
|
||||||
result = session.execute(
|
result = session.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c')
|
INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c')
|
||||||
""")
|
""")
|
||||||
self.assertEquals(None, result)
|
self.assertEqual(None, result)
|
||||||
|
|
||||||
result = session.execute("SELECT * FROM clustertests.cf0")
|
result = session.execute("SELECT * FROM clustertests.cf0")
|
||||||
self.assertEquals([('a', 'b', 'c')], result)
|
self.assertEqual([('a', 'b', 'c')], result)
|
||||||
|
|
||||||
cluster.shutdown()
|
cluster.shutdown()
|
||||||
|
|
||||||
@@ -75,15 +75,15 @@ class ClusterTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
INSERT INTO test3rf.test (k, v) VALUES (8889, 8889)
|
INSERT INTO test3rf.test (k, v) VALUES (8889, 8889)
|
||||||
""")
|
""")
|
||||||
self.assertEquals(None, result)
|
self.assertEqual(None, result)
|
||||||
|
|
||||||
result = session.execute("SELECT * FROM test3rf.test")
|
result = session.execute("SELECT * FROM test3rf.test")
|
||||||
self.assertEquals([(8889, 8889)], result)
|
self.assertEqual([(8889, 8889)], result)
|
||||||
|
|
||||||
# test_connect_on_keyspace
|
# test_connect_on_keyspace
|
||||||
session2 = cluster.connect('test3rf')
|
session2 = cluster.connect('test3rf')
|
||||||
result2 = session2.execute("SELECT * FROM test")
|
result2 = session2.execute("SELECT * FROM test")
|
||||||
self.assertEquals(result, result2)
|
self.assertEqual(result, result2)
|
||||||
|
|
||||||
def test_set_keyspace_twice(self):
|
def test_set_keyspace_twice(self):
|
||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class ClusterTests(unittest.TestCase):
|
|||||||
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
|
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
|
||||||
parameters = [(i, i) for i in range(num_statements)]
|
parameters = [(i, i) for i in range(num_statements)]
|
||||||
|
|
||||||
results = execute_concurrent(self.session, zip(statements, parameters))
|
results = execute_concurrent(self.session, list(zip(statements, parameters)))
|
||||||
self.assertEqual(num_statements, len(results))
|
self.assertEqual(num_statements, len(results))
|
||||||
self.assertEqual([(True, None)] * num_statements, results)
|
self.assertEqual([(True, None)] * num_statements, results)
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ class ClusterTests(unittest.TestCase):
|
|||||||
statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", ))
|
statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", ))
|
||||||
parameters = [(i, ) for i in range(num_statements)]
|
parameters = [(i, ) for i in range(num_statements)]
|
||||||
|
|
||||||
results = execute_concurrent(self.session, zip(statements, parameters))
|
results = execute_concurrent(self.session, list(zip(statements, parameters)))
|
||||||
self.assertEqual(num_statements, len(results))
|
self.assertEqual(num_statements, len(results))
|
||||||
self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)
|
self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ class ClusterTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
InvalidRequest,
|
InvalidRequest,
|
||||||
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)
|
execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)
|
||||||
|
|
||||||
def test_first_failure_client_side(self):
|
def test_first_failure_client_side(self):
|
||||||
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
|
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
|
||||||
@@ -92,7 +92,7 @@ class ClusterTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
TypeError,
|
TypeError,
|
||||||
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)
|
execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)
|
||||||
|
|
||||||
def test_no_raise_on_first_failure(self):
|
def test_no_raise_on_first_failure(self):
|
||||||
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
|
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
|
||||||
@@ -101,7 +101,7 @@ class ClusterTests(unittest.TestCase):
|
|||||||
# we'll get an error back from the server
|
# we'll get an error back from the server
|
||||||
parameters[57] = ('efefef', 'awefawefawef')
|
parameters[57] = ('efefef', 'awefawefawef')
|
||||||
|
|
||||||
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
|
results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
|
||||||
for i, (success, result) in enumerate(results):
|
for i, (success, result) in enumerate(results):
|
||||||
if i == 57:
|
if i == 57:
|
||||||
self.assertFalse(success)
|
self.assertFalse(success)
|
||||||
@@ -115,9 +115,9 @@ class ClusterTests(unittest.TestCase):
|
|||||||
parameters = [(i, i) for i in range(100)]
|
parameters = [(i, i) for i in range(100)]
|
||||||
|
|
||||||
# the driver will raise an error when binding the params
|
# the driver will raise an error when binding the params
|
||||||
parameters[57] = i
|
parameters[57] = 1
|
||||||
|
|
||||||
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
|
results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
|
||||||
for i, (success, result) in enumerate(results):
|
for i, (success, result) in enumerate(results):
|
||||||
if i == 57:
|
if i == 57:
|
||||||
self.assertFalse(success)
|
self.assertFalse(success)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import six
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import unittest2 as unittest
|
import unittest2 as unittest
|
||||||
@@ -115,7 +116,7 @@ class SchemaMetadataTest(unittest.TestCase):
|
|||||||
|
|
||||||
def check_create_statement(self, tablemeta, original):
|
def check_create_statement(self, tablemeta, original):
|
||||||
recreate = tablemeta.as_cql_query(formatted=False)
|
recreate = tablemeta.as_cql_query(formatted=False)
|
||||||
self.assertEquals(original, recreate[:len(original)])
|
self.assertEqual(original, recreate[:len(original)])
|
||||||
self.session.execute("DROP TABLE %s.%s" % (self.ksname, self.cfname))
|
self.session.execute("DROP TABLE %s.%s" % (self.ksname, self.cfname))
|
||||||
self.session.execute(recreate)
|
self.session.execute(recreate)
|
||||||
|
|
||||||
@@ -289,7 +290,7 @@ class SchemaMetadataTest(unittest.TestCase):
|
|||||||
tablemeta = self.get_table_metadata()
|
tablemeta = self.get_table_metadata()
|
||||||
statements = tablemeta.export_as_string().strip()
|
statements = tablemeta.export_as_string().strip()
|
||||||
statements = [s.strip() for s in statements.split(';')]
|
statements = [s.strip() for s in statements.split(';')]
|
||||||
statements = filter(bool, statements)
|
statements = list(filter(bool, statements))
|
||||||
self.assertEqual(3, len(statements))
|
self.assertEqual(3, len(statements))
|
||||||
self.assertEqual(d_index, statements[1])
|
self.assertEqual(d_index, statements[1])
|
||||||
self.assertEqual(e_index, statements[2])
|
self.assertEqual(e_index, statements[2])
|
||||||
@@ -311,7 +312,7 @@ class TestCodeCoverage(unittest.TestCase):
|
|||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
cluster.connect()
|
cluster.connect()
|
||||||
|
|
||||||
self.assertIsInstance(cluster.metadata.export_schema_as_string(), basestring)
|
self.assertIsInstance(cluster.metadata.export_schema_as_string(), six.string_types)
|
||||||
|
|
||||||
def test_export_keyspace_schema(self):
|
def test_export_keyspace_schema(self):
|
||||||
"""
|
"""
|
||||||
@@ -323,8 +324,8 @@ class TestCodeCoverage(unittest.TestCase):
|
|||||||
|
|
||||||
for keyspace in cluster.metadata.keyspaces:
|
for keyspace in cluster.metadata.keyspaces:
|
||||||
keyspace_metadata = cluster.metadata.keyspaces[keyspace]
|
keyspace_metadata = cluster.metadata.keyspaces[keyspace]
|
||||||
self.assertIsInstance(keyspace_metadata.export_as_string(), basestring)
|
self.assertIsInstance(keyspace_metadata.export_as_string(), six.string_types)
|
||||||
self.assertIsInstance(keyspace_metadata.as_cql_query(), basestring)
|
self.assertIsInstance(keyspace_metadata.as_cql_query(), six.string_types)
|
||||||
|
|
||||||
def test_case_sensitivity(self):
|
def test_case_sensitivity(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
bound = prepared.bind(('a'))
|
bound = prepared.bind(('a'))
|
||||||
results = session.execute(bound)
|
results = session.execute(bound)
|
||||||
self.assertEquals(results, [('a', 'b', 'c')])
|
self.assertEqual(results, [('a', 'b', 'c')])
|
||||||
|
|
||||||
# test with new dict binding
|
# test with new dict binding
|
||||||
prepared = session.prepare(
|
prepared = session.prepare(
|
||||||
@@ -95,7 +95,7 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
bound = prepared.bind({'a': 'x'})
|
bound = prepared.bind({'a': 'x'})
|
||||||
results = session.execute(bound)
|
results = session.execute(bound)
|
||||||
self.assertEquals(results, [('x', 'y', 'z')])
|
self.assertEqual(results, [('x', 'y', 'z')])
|
||||||
|
|
||||||
def test_missing_primary_key(self):
|
def test_missing_primary_key(self):
|
||||||
"""
|
"""
|
||||||
@@ -148,7 +148,7 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
self.assertRaises(ValueError, prepared.bind, (1,2))
|
self.assertRaises(ValueError, prepared.bind, (1, 2))
|
||||||
|
|
||||||
def test_too_many_bind_values_dicts(self):
|
def test_too_many_bind_values_dicts(self):
|
||||||
"""
|
"""
|
||||||
@@ -196,7 +196,7 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
bound = prepared.bind((1,))
|
bound = prepared.bind((1,))
|
||||||
results = session.execute(bound)
|
results = session.execute(bound)
|
||||||
self.assertEquals(results[0].v, None)
|
self.assertEqual(results[0].v, None)
|
||||||
|
|
||||||
def test_none_values_dicts(self):
|
def test_none_values_dicts(self):
|
||||||
"""
|
"""
|
||||||
@@ -206,7 +206,6 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
session = cluster.connect()
|
session = cluster.connect()
|
||||||
|
|
||||||
|
|
||||||
# test with new dict binding
|
# test with new dict binding
|
||||||
prepared = session.prepare(
|
prepared = session.prepare(
|
||||||
"""
|
"""
|
||||||
@@ -225,7 +224,7 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
bound = prepared.bind({'k': 1})
|
bound = prepared.bind({'k': 1})
|
||||||
results = session.execute(bound)
|
results = session.execute(bound)
|
||||||
self.assertEquals(results[0].v, None)
|
self.assertEqual(results[0].v, None)
|
||||||
|
|
||||||
def test_async_binding(self):
|
def test_async_binding(self):
|
||||||
"""
|
"""
|
||||||
@@ -252,8 +251,7 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
future = session.execute_async(prepared, (873,))
|
future = session.execute_async(prepared, (873,))
|
||||||
results = future.result()
|
results = future.result()
|
||||||
self.assertEquals(results[0].v, None)
|
self.assertEqual(results[0].v, None)
|
||||||
|
|
||||||
|
|
||||||
def test_async_binding_dicts(self):
|
def test_async_binding_dicts(self):
|
||||||
"""
|
"""
|
||||||
@@ -280,4 +278,4 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
future = session.execute_async(prepared, {'k': 873})
|
future = session.execute_async(prepared, {'k': 873})
|
||||||
results = future.result()
|
results = future.result()
|
||||||
self.assertEquals(results[0].v, None)
|
self.assertEqual(results[0].v, None)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
|
|||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.policies import HostDistance
|
from cassandra.policies import HostDistance
|
||||||
|
|
||||||
from tests.integration import get_server_versions, PROTOCOL_VERSION
|
from tests.integration import PROTOCOL_VERSION
|
||||||
|
|
||||||
|
|
||||||
class QueryTest(unittest.TestCase):
|
class QueryTest(unittest.TestCase):
|
||||||
@@ -43,7 +43,7 @@ class QueryTest(unittest.TestCase):
|
|||||||
self.assertIsInstance(bound, BoundStatement)
|
self.assertIsInstance(bound, BoundStatement)
|
||||||
self.assertEqual(2, len(bound.values))
|
self.assertEqual(2, len(bound.values))
|
||||||
session.execute(bound)
|
session.execute(bound)
|
||||||
self.assertEqual(bound.routing_key, '\x00\x00\x00\x01')
|
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
|
||||||
|
|
||||||
def test_value_sequence(self):
|
def test_value_sequence(self):
|
||||||
"""
|
"""
|
||||||
@@ -102,7 +102,7 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
bound = prepared.bind((1, None))
|
bound = prepared.bind((1, None))
|
||||||
self.assertEqual(bound.routing_key, '\x00\x00\x00\x01')
|
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
|
||||||
|
|
||||||
def test_empty_routing_key_indexes(self):
|
def test_empty_routing_key_indexes(self):
|
||||||
"""
|
"""
|
||||||
@@ -158,7 +158,7 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
bound = prepared.bind((1, 2))
|
bound = prepared.bind((1, 2))
|
||||||
self.assertEqual(bound.routing_key, '\x04\x00\x00\x00\x04\x00\x00\x00')
|
self.assertEqual(bound.routing_key, b'\x04\x00\x00\x00\x04\x00\x00\x00')
|
||||||
|
|
||||||
def test_bound_keyspace(self):
|
def test_bound_keyspace(self):
|
||||||
"""
|
"""
|
||||||
@@ -326,14 +326,14 @@ class SerialConsistencyTests(unittest.TestCase):
|
|||||||
"UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
|
"UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
|
||||||
serial_consistency_level=ConsistencyLevel.SERIAL)
|
serial_consistency_level=ConsistencyLevel.SERIAL)
|
||||||
result = self.session.execute(statement)
|
result = self.session.execute(statement)
|
||||||
self.assertEquals(1, len(result))
|
self.assertEqual(1, len(result))
|
||||||
self.assertFalse(result[0].applied)
|
self.assertFalse(result[0].applied)
|
||||||
|
|
||||||
statement = SimpleStatement(
|
statement = SimpleStatement(
|
||||||
"UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
|
"UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
|
||||||
serial_consistency_level=ConsistencyLevel.SERIAL)
|
serial_consistency_level=ConsistencyLevel.SERIAL)
|
||||||
result = self.session.execute(statement)
|
result = self.session.execute(statement)
|
||||||
self.assertEquals(1, len(result))
|
self.assertEqual(1, len(result))
|
||||||
self.assertTrue(result[0].applied)
|
self.assertTrue(result[0].applied)
|
||||||
|
|
||||||
def test_conditional_update_with_prepared_statements(self):
|
def test_conditional_update_with_prepared_statements(self):
|
||||||
@@ -343,7 +343,7 @@ class SerialConsistencyTests(unittest.TestCase):
|
|||||||
|
|
||||||
statement.serial_consistency_level = ConsistencyLevel.SERIAL
|
statement.serial_consistency_level = ConsistencyLevel.SERIAL
|
||||||
result = self.session.execute(statement)
|
result = self.session.execute(statement)
|
||||||
self.assertEquals(1, len(result))
|
self.assertEqual(1, len(result))
|
||||||
self.assertFalse(result[0].applied)
|
self.assertFalse(result[0].applied)
|
||||||
|
|
||||||
statement = self.session.prepare(
|
statement = self.session.prepare(
|
||||||
@@ -351,7 +351,7 @@ class SerialConsistencyTests(unittest.TestCase):
|
|||||||
bound = statement.bind(())
|
bound = statement.bind(())
|
||||||
bound.serial_consistency_level = ConsistencyLevel.SERIAL
|
bound.serial_consistency_level = ConsistencyLevel.SERIAL
|
||||||
result = self.session.execute(statement)
|
result = self.session.execute(statement)
|
||||||
self.assertEquals(1, len(result))
|
self.assertEqual(1, len(result))
|
||||||
self.assertTrue(result[0].applied)
|
self.assertTrue(result[0].applied)
|
||||||
|
|
||||||
def test_bad_consistency_level(self):
|
def test_bad_consistency_level(self):
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ except ImportError:
|
|||||||
import unittest # noqa
|
import unittest # noqa
|
||||||
|
|
||||||
from itertools import cycle, count
|
from itertools import cycle, count
|
||||||
|
from six.moves import range
|
||||||
from threading import Event
|
from threading import Event
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
@@ -47,7 +48,7 @@ class QueryPagingTests(unittest.TestCase):
|
|||||||
def test_paging(self):
|
def test_paging(self):
|
||||||
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
|
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
|
||||||
[(i, ) for i in range(100)])
|
[(i, ) for i in range(100)])
|
||||||
execute_concurrent(self.session, statements_and_params)
|
execute_concurrent(self.session, list(statements_and_params))
|
||||||
|
|
||||||
prepared = self.session.prepare("SELECT * FROM test3rf.test")
|
prepared = self.session.prepare("SELECT * FROM test3rf.test")
|
||||||
|
|
||||||
@@ -153,7 +154,7 @@ class QueryPagingTests(unittest.TestCase):
|
|||||||
def test_async_paging(self):
|
def test_async_paging(self):
|
||||||
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
|
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
|
||||||
[(i, ) for i in range(100)])
|
[(i, ) for i in range(100)])
|
||||||
execute_concurrent(self.session, statements_and_params)
|
execute_concurrent(self.session, list(statements_and_params))
|
||||||
|
|
||||||
prepared = self.session.prepare("SELECT * FROM test3rf.test")
|
prepared = self.session.prepare("SELECT * FROM test3rf.test")
|
||||||
|
|
||||||
@@ -219,7 +220,7 @@ class QueryPagingTests(unittest.TestCase):
|
|||||||
def test_paging_callbacks(self):
|
def test_paging_callbacks(self):
|
||||||
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
|
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
|
||||||
[(i, ) for i in range(100)])
|
[(i, ) for i in range(100)])
|
||||||
execute_concurrent(self.session, statements_and_params)
|
execute_concurrent(self.session, list(statements_and_params))
|
||||||
|
|
||||||
prepared = self.session.prepare("SELECT * FROM test3rf.test")
|
prepared = self.session.prepare("SELECT * FROM test3rf.test")
|
||||||
|
|
||||||
@@ -232,7 +233,7 @@ class QueryPagingTests(unittest.TestCase):
|
|||||||
|
|
||||||
def handle_page(rows, future, counter):
|
def handle_page(rows, future, counter):
|
||||||
for row in rows:
|
for row in rows:
|
||||||
counter.next()
|
next(counter)
|
||||||
|
|
||||||
if future.has_more_pages:
|
if future.has_more_pages:
|
||||||
future.start_fetching_next_page()
|
future.start_fetching_next_page()
|
||||||
@@ -245,7 +246,7 @@ class QueryPagingTests(unittest.TestCase):
|
|||||||
|
|
||||||
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
|
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
|
||||||
event.wait()
|
event.wait()
|
||||||
self.assertEquals(counter.next(), 100)
|
self.assertEquals(next(counter), 100)
|
||||||
|
|
||||||
# simple statement
|
# simple statement
|
||||||
future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"))
|
future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"))
|
||||||
@@ -254,7 +255,7 @@ class QueryPagingTests(unittest.TestCase):
|
|||||||
|
|
||||||
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
|
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
|
||||||
event.wait()
|
event.wait()
|
||||||
self.assertEquals(counter.next(), 100)
|
self.assertEquals(next(counter), 100)
|
||||||
|
|
||||||
# prepared statement
|
# prepared statement
|
||||||
future = self.session.execute_async(prepared)
|
future = self.session.execute_async(prepared)
|
||||||
@@ -263,4 +264,4 @@ class QueryPagingTests(unittest.TestCase):
|
|||||||
|
|
||||||
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
|
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
|
||||||
event.wait()
|
event.wait()
|
||||||
self.assertEquals(counter.next(), 100)
|
self.assertEquals(next(counter), 100)
|
||||||
|
|||||||
@@ -17,8 +17,12 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
import unittest # noqa
|
import unittest # noqa
|
||||||
|
|
||||||
|
import logging
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import six
|
||||||
from uuid import uuid1, uuid4
|
from uuid import uuid1, uuid4
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -59,12 +63,16 @@ class TypeTests(unittest.TestCase):
|
|||||||
|
|
||||||
params = [
|
params = [
|
||||||
'key1',
|
'key1',
|
||||||
'blobyblob'.encode('hex')
|
b'blobyblob'
|
||||||
]
|
]
|
||||||
|
|
||||||
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s)'
|
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s)'
|
||||||
|
|
||||||
if self._cql_version >= (3, 1, 0):
|
# In python 3, the 'bytes' type is treated as a blob, so we can
|
||||||
|
# correctly encode it with hex notation.
|
||||||
|
# In python2, we don't treat the 'str' type as a blob, so we'll encode it
|
||||||
|
# as a string literal and have the following failure.
|
||||||
|
if six.PY2 and self._cql_version >= (3, 1, 0):
|
||||||
# Blob values can't be specified using string notation in CQL 3.1.0 and
|
# Blob values can't be specified using string notation in CQL 3.1.0 and
|
||||||
# above which is used by default in Cassandra 2.0.
|
# above which is used by default in Cassandra 2.0.
|
||||||
msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*'
|
msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*'
|
||||||
@@ -74,13 +82,13 @@ class TypeTests(unittest.TestCase):
|
|||||||
s.execute(query, params)
|
s.execute(query, params)
|
||||||
expected_vals = [
|
expected_vals = [
|
||||||
'key1',
|
'key1',
|
||||||
'blobyblob'
|
bytearray(b'blobyblob')
|
||||||
]
|
]
|
||||||
|
|
||||||
results = s.execute("SELECT * FROM mytable")
|
results = s.execute("SELECT * FROM mytable")
|
||||||
|
|
||||||
for expected, actual in zip(expected_vals, results[0]):
|
for expected, actual in zip(expected_vals, results[0]):
|
||||||
self.assertEquals(expected, actual)
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
def test_blob_type_as_bytearray(self):
|
def test_blob_type_as_bytearray(self):
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
@@ -101,7 +109,7 @@ class TypeTests(unittest.TestCase):
|
|||||||
|
|
||||||
params = [
|
params = [
|
||||||
'key1',
|
'key1',
|
||||||
bytearray('blob1', 'hex')
|
bytearray(b'blob1')
|
||||||
]
|
]
|
||||||
|
|
||||||
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s);'
|
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s);'
|
||||||
@@ -109,13 +117,13 @@ class TypeTests(unittest.TestCase):
|
|||||||
|
|
||||||
expected_vals = [
|
expected_vals = [
|
||||||
'key1',
|
'key1',
|
||||||
bytearray('blob1', 'hex')
|
bytearray(b'blob1')
|
||||||
]
|
]
|
||||||
|
|
||||||
results = s.execute("SELECT * FROM mytable")
|
results = s.execute("SELECT * FROM mytable")
|
||||||
|
|
||||||
for expected, actual in zip(expected_vals, results[0]):
|
for expected, actual in zip(expected_vals, results[0]):
|
||||||
self.assertEquals(expected, actual)
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
create_type_table = """
|
create_type_table = """
|
||||||
CREATE TABLE mytable (
|
CREATE TABLE mytable (
|
||||||
@@ -208,7 +216,7 @@ class TypeTests(unittest.TestCase):
|
|||||||
results = s.execute("SELECT * FROM mytable")
|
results = s.execute("SELECT * FROM mytable")
|
||||||
|
|
||||||
for expected, actual in zip(expected_vals, results[0]):
|
for expected, actual in zip(expected_vals, results[0]):
|
||||||
self.assertEquals(expected, actual)
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
# try the same thing with a prepared statement
|
# try the same thing with a prepared statement
|
||||||
prepared = s.prepare("""
|
prepared = s.prepare("""
|
||||||
@@ -221,7 +229,7 @@ class TypeTests(unittest.TestCase):
|
|||||||
results = s.execute("SELECT * FROM mytable")
|
results = s.execute("SELECT * FROM mytable")
|
||||||
|
|
||||||
for expected, actual in zip(expected_vals, results[0]):
|
for expected, actual in zip(expected_vals, results[0]):
|
||||||
self.assertEquals(expected, actual)
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
# query with prepared statement
|
# query with prepared statement
|
||||||
prepared = s.prepare("""
|
prepared = s.prepare("""
|
||||||
@@ -230,14 +238,14 @@ class TypeTests(unittest.TestCase):
|
|||||||
results = s.execute(prepared.bind(()))
|
results = s.execute(prepared.bind(()))
|
||||||
|
|
||||||
for expected, actual in zip(expected_vals, results[0]):
|
for expected, actual in zip(expected_vals, results[0]):
|
||||||
self.assertEquals(expected, actual)
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
# query with prepared statement, no explicit columns
|
# query with prepared statement, no explicit columns
|
||||||
prepared = s.prepare("""SELECT * FROM mytable""")
|
prepared = s.prepare("""SELECT * FROM mytable""")
|
||||||
results = s.execute(prepared.bind(()))
|
results = s.execute(prepared.bind(()))
|
||||||
|
|
||||||
for expected, actual in zip(expected_vals, results[0]):
|
for expected, actual in zip(expected_vals, results[0]):
|
||||||
self.assertEquals(expected, actual)
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
def test_empty_strings_and_nones(self):
|
def test_empty_strings_and_nones(self):
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
@@ -265,11 +273,11 @@ class TypeTests(unittest.TestCase):
|
|||||||
# insert empty strings for string-like fields and fetch them
|
# insert empty strings for string-like fields and fetch them
|
||||||
s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)",
|
s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)",
|
||||||
('', '', '', [''], {'': 3}))
|
('', '', '', [''], {'': 3}))
|
||||||
self.assertEquals(
|
self.assertEqual(
|
||||||
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
|
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
|
||||||
s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0])
|
s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0])
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEqual(
|
||||||
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
|
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
|
||||||
s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0])
|
s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0])
|
||||||
|
|
||||||
@@ -363,7 +371,7 @@ class TypeTests(unittest.TestCase):
|
|||||||
""" Ensure timezone-aware datetimes are converted to timestamps correctly """
|
""" Ensure timezone-aware datetimes are converted to timestamps correctly """
|
||||||
try:
|
try:
|
||||||
import pytz
|
import pytz
|
||||||
except ImportError, exc:
|
except ImportError as exc:
|
||||||
raise unittest.SkipTest('pytz is not available: %r' % (exc,))
|
raise unittest.SkipTest('pytz is not available: %r' % (exc,))
|
||||||
|
|
||||||
dt = datetime(1997, 8, 29, 11, 14)
|
dt = datetime(1997, 8, 29, 11, 14)
|
||||||
@@ -381,10 +389,10 @@ class TypeTests(unittest.TestCase):
|
|||||||
# test non-prepared statement
|
# test non-prepared statement
|
||||||
s.execute("INSERT INTO mytable (a, b) VALUES ('key1', %s)", parameters=(dt,))
|
s.execute("INSERT INTO mytable (a, b) VALUES ('key1', %s)", parameters=(dt,))
|
||||||
result = s.execute("SELECT b FROM mytable WHERE a='key1'")[0].b
|
result = s.execute("SELECT b FROM mytable WHERE a='key1'")[0].b
|
||||||
self.assertEquals(dt.utctimetuple(), result.utctimetuple())
|
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
|
||||||
|
|
||||||
# test prepared statement
|
# test prepared statement
|
||||||
prepared = s.prepare("INSERT INTO mytable (a, b) VALUES ('key2', ?)")
|
prepared = s.prepare("INSERT INTO mytable (a, b) VALUES ('key2', ?)")
|
||||||
s.execute(prepared, parameters=(dt,))
|
s.execute(prepared, parameters=(dt,))
|
||||||
result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b
|
result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b
|
||||||
self.assertEquals(dt.utctimetuple(), result.utctimetuple())
|
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import six
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import unittest2 as unittest
|
import unittest2 as unittest
|
||||||
@@ -19,7 +20,9 @@ except ImportError:
|
|||||||
|
|
||||||
import errno
|
import errno
|
||||||
import os
|
import os
|
||||||
from StringIO import StringIO
|
|
||||||
|
from six import BytesIO
|
||||||
|
|
||||||
import socket
|
import socket
|
||||||
from socket import error as socket_error
|
from socket import error as socket_error
|
||||||
|
|
||||||
@@ -55,7 +58,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
def make_header_prefix(self, message_class, version=2, stream_id=0):
|
def make_header_prefix(self, message_class, version=2, stream_id=0):
|
||||||
return ''.join(map(uint8_pack, [
|
return six.binary_type().join(map(uint8_pack, [
|
||||||
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
|
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
|
||||||
0, # flags (compression)
|
0, # flags (compression)
|
||||||
stream_id,
|
stream_id,
|
||||||
@@ -63,7 +66,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
]))
|
]))
|
||||||
|
|
||||||
def make_options_body(self):
|
def make_options_body(self):
|
||||||
options_buf = StringIO()
|
options_buf = BytesIO()
|
||||||
write_stringmultimap(options_buf, {
|
write_stringmultimap(options_buf, {
|
||||||
'CQL_VERSION': ['3.0.1'],
|
'CQL_VERSION': ['3.0.1'],
|
||||||
'COMPRESSION': []
|
'COMPRESSION': []
|
||||||
@@ -71,12 +74,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
return options_buf.getvalue()
|
return options_buf.getvalue()
|
||||||
|
|
||||||
def make_error_body(self, code, msg):
|
def make_error_body(self, code, msg):
|
||||||
buf = StringIO()
|
buf = BytesIO()
|
||||||
write_int(buf, code)
|
write_int(buf, code)
|
||||||
write_string(buf, msg)
|
write_string(buf, msg)
|
||||||
return buf.getvalue()
|
return buf.getvalue()
|
||||||
|
|
||||||
def make_msg(self, header, body=""):
|
def make_msg(self, header, body=six.binary_type()):
|
||||||
return header + uint32_pack(len(body)) + body
|
return header + uint32_pack(len(body)) + body
|
||||||
|
|
||||||
def test_successful_connection(self, *args):
|
def test_successful_connection(self, *args):
|
||||||
@@ -105,12 +108,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
# get a connection that's already fully started
|
# get a connection that's already fully started
|
||||||
c = self.test_successful_connection()
|
c = self.test_successful_connection()
|
||||||
|
|
||||||
header = '\x00\x00\x00\x00' + int32_pack(20000)
|
header = six.b('\x00\x00\x00\x00') + int32_pack(20000)
|
||||||
responses = [
|
responses = [
|
||||||
header + ('a' * (4096 - len(header))),
|
header + (six.b('a') * (4096 - len(header))),
|
||||||
'a' * 4096,
|
six.b('a') * 4096,
|
||||||
socket_error(errno.EAGAIN),
|
socket_error(errno.EAGAIN),
|
||||||
'a' * 100,
|
six.b('a') * 100,
|
||||||
socket_error(errno.EAGAIN)]
|
socket_error(errno.EAGAIN)]
|
||||||
|
|
||||||
def side_effect(*args):
|
def side_effect(*args):
|
||||||
@@ -122,17 +125,17 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
|
|
||||||
c.socket.recv.side_effect = side_effect
|
c.socket.recv.side_effect = side_effect
|
||||||
c.handle_read()
|
c.handle_read()
|
||||||
self.assertEquals(c._total_reqd_bytes, 20000 + len(header))
|
self.assertEqual(c._total_reqd_bytes, 20000 + len(header))
|
||||||
# the EAGAIN prevents it from reading the last 100 bytes
|
# the EAGAIN prevents it from reading the last 100 bytes
|
||||||
c._iobuf.seek(0, os.SEEK_END)
|
c._iobuf.seek(0, os.SEEK_END)
|
||||||
pos = c._iobuf.tell()
|
pos = c._iobuf.tell()
|
||||||
self.assertEquals(pos, 4096 + 4096)
|
self.assertEqual(pos, 4096 + 4096)
|
||||||
|
|
||||||
# now tell it to read the last 100 bytes
|
# now tell it to read the last 100 bytes
|
||||||
c.handle_read()
|
c.handle_read()
|
||||||
c._iobuf.seek(0, os.SEEK_END)
|
c._iobuf.seek(0, os.SEEK_END)
|
||||||
pos = c._iobuf.tell()
|
pos = c._iobuf.tell()
|
||||||
self.assertEquals(pos, 4096 + 4096 + 100)
|
self.assertEqual(pos, 4096 + 4096 + 100)
|
||||||
|
|
||||||
def test_protocol_error(self, *args):
|
def test_protocol_error(self, *args):
|
||||||
c = self.make_connection()
|
c = self.make_connection()
|
||||||
@@ -237,14 +240,13 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
options = self.make_options_body()
|
options = self.make_options_body()
|
||||||
message = self.make_msg(header, options)
|
message = self.make_msg(header, options)
|
||||||
|
|
||||||
# read in the first byte
|
c.socket.recv.return_value = message[0:1]
|
||||||
c.socket.recv.return_value = message[0]
|
|
||||||
c.handle_read()
|
c.handle_read()
|
||||||
self.assertEquals(c._iobuf.getvalue(), message[0])
|
self.assertEqual(c._iobuf.getvalue(), message[0:1])
|
||||||
|
|
||||||
c.socket.recv.return_value = message[1:]
|
c.socket.recv.return_value = message[1:]
|
||||||
c.handle_read()
|
c.handle_read()
|
||||||
self.assertEquals("", c._iobuf.getvalue())
|
self.assertEqual(six.binary_type(), c._iobuf.getvalue())
|
||||||
|
|
||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write()
|
c.handle_write()
|
||||||
@@ -266,12 +268,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
# read in the first nine bytes
|
# read in the first nine bytes
|
||||||
c.socket.recv.return_value = message[:9]
|
c.socket.recv.return_value = message[:9]
|
||||||
c.handle_read()
|
c.handle_read()
|
||||||
self.assertEquals(c._iobuf.getvalue(), message[:9])
|
self.assertEqual(c._iobuf.getvalue(), message[:9])
|
||||||
|
|
||||||
# ... then read in the rest
|
# ... then read in the rest
|
||||||
c.socket.recv.return_value = message[9:]
|
c.socket.recv.return_value = message[9:]
|
||||||
c.handle_read()
|
c.handle_read()
|
||||||
self.assertEquals("", c._iobuf.getvalue())
|
self.assertEqual(six.binary_type(), c._iobuf.getvalue())
|
||||||
|
|
||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write()
|
c.handle_write()
|
||||||
|
|||||||
@@ -19,7 +19,9 @@ except ImportError:
|
|||||||
|
|
||||||
import errno
|
import errno
|
||||||
import os
|
import os
|
||||||
from StringIO import StringIO
|
|
||||||
|
from six.moves import StringIO
|
||||||
|
|
||||||
from socket import error as socket_error
|
from socket import error as socket_error
|
||||||
|
|
||||||
from mock import patch, Mock
|
from mock import patch, Mock
|
||||||
@@ -122,17 +124,17 @@ class LibevConnectionTest(unittest.TestCase):
|
|||||||
|
|
||||||
c._socket.recv.side_effect = side_effect
|
c._socket.recv.side_effect = side_effect
|
||||||
c.handle_read(None, 0)
|
c.handle_read(None, 0)
|
||||||
self.assertEquals(c._total_reqd_bytes, 20000 + len(header))
|
self.assertEqual(c._total_reqd_bytes, 20000 + len(header))
|
||||||
# the EAGAIN prevents it from reading the last 100 bytes
|
# the EAGAIN prevents it from reading the last 100 bytes
|
||||||
c._iobuf.seek(0, os.SEEK_END)
|
c._iobuf.seek(0, os.SEEK_END)
|
||||||
pos = c._iobuf.tell()
|
pos = c._iobuf.tell()
|
||||||
self.assertEquals(pos, 4096 + 4096)
|
self.assertEqual(pos, 4096 + 4096)
|
||||||
|
|
||||||
# now tell it to read the last 100 bytes
|
# now tell it to read the last 100 bytes
|
||||||
c.handle_read(None, 0)
|
c.handle_read(None, 0)
|
||||||
c._iobuf.seek(0, os.SEEK_END)
|
c._iobuf.seek(0, os.SEEK_END)
|
||||||
pos = c._iobuf.tell()
|
pos = c._iobuf.tell()
|
||||||
self.assertEquals(pos, 4096 + 4096 + 100)
|
self.assertEqual(pos, 4096 + 4096 + 100)
|
||||||
|
|
||||||
def test_protocol_error(self, *args):
|
def test_protocol_error(self, *args):
|
||||||
c = self.make_connection()
|
c = self.make_connection()
|
||||||
@@ -240,11 +242,11 @@ class LibevConnectionTest(unittest.TestCase):
|
|||||||
# read in the first byte
|
# read in the first byte
|
||||||
c._socket.recv.return_value = message[0]
|
c._socket.recv.return_value = message[0]
|
||||||
c.handle_read(None, 0)
|
c.handle_read(None, 0)
|
||||||
self.assertEquals(c._iobuf.getvalue(), message[0])
|
self.assertEqual(c._iobuf.getvalue(), message[0])
|
||||||
|
|
||||||
c._socket.recv.return_value = message[1:]
|
c._socket.recv.return_value = message[1:]
|
||||||
c.handle_read(None, 0)
|
c.handle_read(None, 0)
|
||||||
self.assertEquals("", c._iobuf.getvalue())
|
self.assertEqual("", c._iobuf.getvalue())
|
||||||
|
|
||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write(None, 0)
|
c.handle_write(None, 0)
|
||||||
@@ -266,12 +268,12 @@ class LibevConnectionTest(unittest.TestCase):
|
|||||||
# read in the first nine bytes
|
# read in the first nine bytes
|
||||||
c._socket.recv.return_value = message[:9]
|
c._socket.recv.return_value = message[:9]
|
||||||
c.handle_read(None, 0)
|
c.handle_read(None, 0)
|
||||||
self.assertEquals(c._iobuf.getvalue(), message[:9])
|
self.assertEqual(c._iobuf.getvalue(), message[:9])
|
||||||
|
|
||||||
# ... then read in the rest
|
# ... then read in the rest
|
||||||
c._socket.recv.return_value = message[9:]
|
c._socket.recv.return_value = message[9:]
|
||||||
c.handle_read(None, 0)
|
c.handle_read(None, 0)
|
||||||
self.assertEquals("", c._iobuf.getvalue())
|
self.assertEqual("", c._iobuf.getvalue())
|
||||||
|
|
||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write(None, 0)
|
c.handle_write(None, 0)
|
||||||
|
|||||||
@@ -11,17 +11,18 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from cassandra.cluster import Cluster
|
import six
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import unittest2 as unittest
|
import unittest2 as unittest
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import unittest # noqa
|
import unittest # noqa
|
||||||
|
|
||||||
from StringIO import StringIO
|
from six import BytesIO
|
||||||
|
|
||||||
from mock import Mock, ANY
|
from mock import Mock, ANY
|
||||||
|
|
||||||
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT,
|
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT,
|
||||||
HEADER_DIRECTION_FROM_CLIENT, ProtocolError)
|
HEADER_DIRECTION_FROM_CLIENT, ProtocolError)
|
||||||
from cassandra.decoder import (write_stringmultimap, write_int, write_string,
|
from cassandra.decoder import (write_stringmultimap, write_int, write_string,
|
||||||
@@ -40,7 +41,7 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
def make_header_prefix(self, message_class, version=2, stream_id=0):
|
def make_header_prefix(self, message_class, version=2, stream_id=0):
|
||||||
return ''.join(map(uint8_pack, [
|
return six.binary_type().join(map(uint8_pack, [
|
||||||
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
|
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
|
||||||
0, # flags (compression)
|
0, # flags (compression)
|
||||||
stream_id,
|
stream_id,
|
||||||
@@ -48,7 +49,7 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
]))
|
]))
|
||||||
|
|
||||||
def make_options_body(self):
|
def make_options_body(self):
|
||||||
options_buf = StringIO()
|
options_buf = BytesIO()
|
||||||
write_stringmultimap(options_buf, {
|
write_stringmultimap(options_buf, {
|
||||||
'CQL_VERSION': ['3.0.1'],
|
'CQL_VERSION': ['3.0.1'],
|
||||||
'COMPRESSION': []
|
'COMPRESSION': []
|
||||||
@@ -56,7 +57,7 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
return options_buf.getvalue()
|
return options_buf.getvalue()
|
||||||
|
|
||||||
def make_error_body(self, code, msg):
|
def make_error_body(self, code, msg):
|
||||||
buf = StringIO()
|
buf = BytesIO()
|
||||||
write_int(buf, code)
|
write_int(buf, code)
|
||||||
write_string(buf, msg)
|
write_string(buf, msg)
|
||||||
return buf.getvalue()
|
return buf.getvalue()
|
||||||
@@ -88,12 +89,12 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
c.defunct = Mock()
|
c.defunct = Mock()
|
||||||
|
|
||||||
# read in a SupportedMessage response
|
# read in a SupportedMessage response
|
||||||
header = ''.join(map(uint8_pack, [
|
header = six.binary_type().join(uint8_pack(i) for i in (
|
||||||
0xff & (HEADER_DIRECTION_FROM_CLIENT | self.protocol_version),
|
0xff & (HEADER_DIRECTION_FROM_CLIENT | self.protocol_version),
|
||||||
0, # flags (compression)
|
0, # flags (compression)
|
||||||
0,
|
0,
|
||||||
SupportedMessage.opcode # opcode
|
SupportedMessage.opcode # opcode
|
||||||
]))
|
))
|
||||||
options = self.make_options_body()
|
options = self.make_options_body()
|
||||||
message = self.make_msg(header, options)
|
message = self.make_msg(header, options)
|
||||||
c.process_msg(message, len(message) - 8)
|
c.process_msg(message, len(message) - 8)
|
||||||
@@ -130,7 +131,7 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
# read in a SupportedMessage response
|
# read in a SupportedMessage response
|
||||||
header = self.make_header_prefix(SupportedMessage)
|
header = self.make_header_prefix(SupportedMessage)
|
||||||
|
|
||||||
options_buf = StringIO()
|
options_buf = BytesIO()
|
||||||
write_stringmultimap(options_buf, {
|
write_stringmultimap(options_buf, {
|
||||||
'CQL_VERSION': ['7.8.9'],
|
'CQL_VERSION': ['7.8.9'],
|
||||||
'COMPRESSION': []
|
'COMPRESSION': []
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from cassandra.marshal import bitlength
|
|||||||
try:
|
try:
|
||||||
import unittest2 as unittest
|
import unittest2 as unittest
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import unittest # noqa
|
import unittest # noqa
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -33,57 +33,57 @@ from cassandra.util import OrderedDict
|
|||||||
|
|
||||||
marshalled_value_pairs = (
|
marshalled_value_pairs = (
|
||||||
# binary form, type, python native type
|
# binary form, type, python native type
|
||||||
('lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'),
|
(b'lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'),
|
||||||
('', 'AsciiType', ''),
|
(b'', 'AsciiType', ''),
|
||||||
('\x01', 'BooleanType', True),
|
(b'\x01', 'BooleanType', True),
|
||||||
('\x00', 'BooleanType', False),
|
(b'\x00', 'BooleanType', False),
|
||||||
('', 'BooleanType', None),
|
(b'', 'BooleanType', None),
|
||||||
('\xff\xfe\xfd\xfc\xfb', 'BytesType', '\xff\xfe\xfd\xfc\xfb'),
|
(b'\xff\xfe\xfd\xfc\xfb', 'BytesType', b'\xff\xfe\xfd\xfc\xfb'),
|
||||||
('', 'BytesType', ''),
|
(b'', 'BytesType', b''),
|
||||||
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807),
|
(b'\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807),
|
||||||
('\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808),
|
(b'\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808),
|
||||||
('', 'CounterColumnType', None),
|
(b'', 'CounterColumnType', None),
|
||||||
('\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)),
|
(b'\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)),
|
||||||
('', 'DateType', None),
|
(b'', 'DateType', None),
|
||||||
('\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')),
|
(b'\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')),
|
||||||
('\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')),
|
(b'\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')),
|
||||||
('\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')),
|
(b'\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')),
|
||||||
('\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')),
|
(b'\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')),
|
||||||
('\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')),
|
(b'\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')),
|
||||||
('', 'DecimalType', None),
|
(b'', 'DecimalType', None),
|
||||||
('@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125),
|
(b'@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125),
|
||||||
('\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125),
|
(b'\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125),
|
||||||
('\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308),
|
(b'\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308),
|
||||||
('', 'DoubleType', None),
|
(b'', 'DoubleType', None),
|
||||||
('F\x97\xd0@', 'FloatType', 19432.125),
|
(b'F\x97\xd0@', 'FloatType', 19432.125),
|
||||||
('\xc6\x97\xd0@', 'FloatType', -19432.125),
|
(b'\xc6\x97\xd0@', 'FloatType', -19432.125),
|
||||||
('\xc6\x97\xd0@', 'FloatType', -19432.125),
|
(b'\xc6\x97\xd0@', 'FloatType', -19432.125),
|
||||||
('\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0),
|
(b'\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0),
|
||||||
('', 'FloatType', None),
|
(b'', 'FloatType', None),
|
||||||
('\x7f\x50\x00\x00', 'Int32Type', 2135949312),
|
(b'\x7f\x50\x00\x00', 'Int32Type', 2135949312),
|
||||||
('\xff\xfd\xcb\x91', 'Int32Type', -144495),
|
(b'\xff\xfd\xcb\x91', 'Int32Type', -144495),
|
||||||
('', 'Int32Type', None),
|
(b'', 'Int32Type', None),
|
||||||
('f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789),
|
(b'f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789),
|
||||||
('', 'IntegerType', None),
|
(b'', 'IntegerType', None),
|
||||||
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807),
|
(b'\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807),
|
||||||
('\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808),
|
(b'\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808),
|
||||||
('', 'LongType', None),
|
(b'', 'LongType', None),
|
||||||
('', 'InetAddressType', None),
|
(b'', 'InetAddressType', None),
|
||||||
('A46\xa9', 'InetAddressType', '65.52.54.169'),
|
(b'A46\xa9', 'InetAddressType', '65.52.54.169'),
|
||||||
('*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'),
|
(b'*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'),
|
||||||
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'),
|
(b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'),
|
||||||
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000),
|
(b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000),
|
||||||
('', 'UTF8Type', u''),
|
(b'', 'UTF8Type', u''),
|
||||||
('\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')),
|
(b'\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')),
|
||||||
('I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')),
|
(b'I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')),
|
||||||
('', 'UUIDType', None),
|
(b'', 'UUIDType', None),
|
||||||
('', 'MapType(AsciiType, BooleanType)', None),
|
(b'', 'MapType(AsciiType, BooleanType)', None),
|
||||||
('', 'ListType(FloatType)', None),
|
(b'', 'ListType(FloatType)', None),
|
||||||
('', 'SetType(LongType)', None),
|
(b'', 'SetType(LongType)', None),
|
||||||
('\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()),
|
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()),
|
||||||
('\x00\x00', 'ListType(FloatType)', ()),
|
(b'\x00\x00', 'ListType(FloatType)', ()),
|
||||||
('\x00\x00', 'SetType(IntegerType)', sortedset()),
|
(b'\x00\x00', 'SetType(IntegerType)', sortedset()),
|
||||||
('\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', (UUID(bytes='\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'),)),
|
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', (UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'),)),
|
||||||
)
|
)
|
||||||
|
|
||||||
ordered_dict_value = OrderedDict()
|
ordered_dict_value = OrderedDict()
|
||||||
@@ -94,9 +94,9 @@ ordered_dict_value[u'\\'] = 0
|
|||||||
# these following entries work for me right now, but they're dependent on
|
# these following entries work for me right now, but they're dependent on
|
||||||
# vagaries of internal python ordering for unordered types
|
# vagaries of internal python ordering for unordered types
|
||||||
marshalled_value_pairs_unsafe = (
|
marshalled_value_pairs_unsafe = (
|
||||||
('\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value),
|
(b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value),
|
||||||
('\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])),
|
(b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])),
|
||||||
('\x00', 'IntegerType', 0),
|
(b'\x00', 'IntegerType', 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
if platform.python_implementation() == 'CPython':
|
if platform.python_implementation() == 'CPython':
|
||||||
|
|||||||
@@ -29,6 +29,12 @@ from cassandra.pool import Host
|
|||||||
|
|
||||||
class TestStrategies(unittest.TestCase):
|
class TestStrategies(unittest.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
"Hook method for setting up class fixture before running tests in the class."
|
||||||
|
if not hasattr(cls, 'assertItemsEqual'):
|
||||||
|
cls.assertItemsEqual = cls.assertCountEqual
|
||||||
|
|
||||||
def test_replication_strategy(self):
|
def test_replication_strategy(self):
|
||||||
"""
|
"""
|
||||||
Basic code coverage testing that ensures different ReplicationStrategies
|
Basic code coverage testing that ensures different ReplicationStrategies
|
||||||
@@ -217,22 +223,22 @@ class TestTokens(unittest.TestCase):
|
|||||||
murmur3_token = Murmur3Token(cassandra.metadata.MIN_LONG - 1)
|
murmur3_token = Murmur3Token(cassandra.metadata.MIN_LONG - 1)
|
||||||
self.assertEqual(murmur3_token.hash_fn('123'), -7468325962851647638)
|
self.assertEqual(murmur3_token.hash_fn('123'), -7468325962851647638)
|
||||||
self.assertEqual(murmur3_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 7162290910810015547)
|
self.assertEqual(murmur3_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 7162290910810015547)
|
||||||
self.assertEqual(str(murmur3_token), '<Murmur3Token: -9223372036854775809L>')
|
self.assertEqual(str(murmur3_token), '<Murmur3Token: -9223372036854775809>')
|
||||||
except NoMurmur3:
|
except NoMurmur3:
|
||||||
raise unittest.SkipTest('The murmur3 extension is not available')
|
raise unittest.SkipTest('The murmur3 extension is not available')
|
||||||
|
|
||||||
def test_md5_tokens(self):
|
def test_md5_tokens(self):
|
||||||
md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1)
|
md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1)
|
||||||
self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808L)
|
self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808)
|
||||||
self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639L)
|
self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639)
|
||||||
self.assertEqual(str(md5_token), '<MD5Token: -9223372036854775809L>')
|
self.assertEqual(str(md5_token), '<MD5Token: %s>' % -9223372036854775809)
|
||||||
|
|
||||||
def test_bytes_tokens(self):
|
def test_bytes_tokens(self):
|
||||||
bytes_token = BytesToken(str(cassandra.metadata.MIN_LONG - 1))
|
bytes_token = BytesToken(str(cassandra.metadata.MIN_LONG - 1))
|
||||||
self.assertEqual(bytes_token.hash_fn('123'), '123')
|
self.assertEqual(bytes_token.hash_fn('123'), '123')
|
||||||
self.assertEqual(bytes_token.hash_fn(123), 123)
|
self.assertEqual(bytes_token.hash_fn(123), 123)
|
||||||
self.assertEqual(bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)), str(cassandra.metadata.MAX_LONG))
|
self.assertEqual(bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)), str(cassandra.metadata.MAX_LONG))
|
||||||
self.assertEqual(str(bytes_token), "<BytesToken: '-9223372036854775809'>")
|
self.assertEqual(str(bytes_token), "<BytesToken: -9223372036854775809>")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
bytes_token = BytesToken(cassandra.metadata.MIN_LONG - 1)
|
bytes_token = BytesToken(cassandra.metadata.MIN_LONG - 1)
|
||||||
|
|||||||
@@ -22,32 +22,34 @@ from cassandra.query import PreparedStatement, BoundStatement
|
|||||||
from cassandra.cqltypes import Int32Type
|
from cassandra.cqltypes import Int32Type
|
||||||
from cassandra.util import OrderedDict
|
from cassandra.util import OrderedDict
|
||||||
|
|
||||||
|
from six.moves import xrange
|
||||||
|
|
||||||
|
|
||||||
class ParamBindingTest(unittest.TestCase):
|
class ParamBindingTest(unittest.TestCase):
|
||||||
|
|
||||||
def test_bind_sequence(self):
|
def test_bind_sequence(self):
|
||||||
result = bind_params("%s %s %s", (1, "a", 2.0))
|
result = bind_params("%s %s %s", (1, "a", 2.0))
|
||||||
self.assertEquals(result, "1 'a' 2.0")
|
self.assertEqual(result, "1 'a' 2.0")
|
||||||
|
|
||||||
def test_bind_map(self):
|
def test_bind_map(self):
|
||||||
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0))
|
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0))
|
||||||
self.assertEquals(result, "1 'a' 2.0")
|
self.assertEqual(result, "1 'a' 2.0")
|
||||||
|
|
||||||
def test_sequence_param(self):
|
def test_sequence_param(self):
|
||||||
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),))
|
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),))
|
||||||
self.assertEquals(result, "( 1 , 'a' , 2.0 )")
|
self.assertEqual(result, "( 1 , 'a' , 2.0 )")
|
||||||
|
|
||||||
def test_generator_param(self):
|
def test_generator_param(self):
|
||||||
result = bind_params("%s", ((i for i in xrange(3)),))
|
result = bind_params("%s", ((i for i in xrange(3)),))
|
||||||
self.assertEquals(result, "[ 0 , 1 , 2 ]")
|
self.assertEqual(result, "[ 0 , 1 , 2 ]")
|
||||||
|
|
||||||
def test_none_param(self):
|
def test_none_param(self):
|
||||||
result = bind_params("%s", (None,))
|
result = bind_params("%s", (None,))
|
||||||
self.assertEquals(result, "NULL")
|
self.assertEqual(result, "NULL")
|
||||||
|
|
||||||
def test_list_collection(self):
|
def test_list_collection(self):
|
||||||
result = bind_params("%s", (['a', 'b', 'c'],))
|
result = bind_params("%s", (['a', 'b', 'c'],))
|
||||||
self.assertEquals(result, "[ 'a' , 'b' , 'c' ]")
|
self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
|
||||||
|
|
||||||
def test_set_collection(self):
|
def test_set_collection(self):
|
||||||
result = bind_params("%s", (set(['a', 'b']),))
|
result = bind_params("%s", (set(['a', 'b']),))
|
||||||
@@ -59,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
|
|||||||
vals['b'] = 'b'
|
vals['b'] = 'b'
|
||||||
vals['c'] = 'c'
|
vals['c'] = 'c'
|
||||||
result = bind_params("%s", (vals,))
|
result = bind_params("%s", (vals,))
|
||||||
self.assertEquals(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
|
self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
|
||||||
|
|
||||||
def test_quote_escaping(self):
|
def test_quote_escaping(self):
|
||||||
result = bind_params("%s", ("""'ef''ef"ef""ef'""",))
|
result = bind_params("%s", ("""'ef''ef"ef""ef'""",))
|
||||||
self.assertEquals(result, """'''ef''''ef"ef""ef'''""")
|
self.assertEqual(result, """'''ef''''ef"ef""ef'''""")
|
||||||
|
|
||||||
|
|
||||||
class BoundStatementTestCase(unittest.TestCase):
|
class BoundStatementTestCase(unittest.TestCase):
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ except ImportError:
|
|||||||
from itertools import islice, cycle
|
from itertools import islice, cycle
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
from random import randint
|
from random import randint
|
||||||
|
import six
|
||||||
import sys
|
import sys
|
||||||
import struct
|
import struct
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@@ -36,6 +37,8 @@ from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
|
|||||||
from cassandra.pool import Host
|
from cassandra.pool import Host
|
||||||
from cassandra.query import Statement
|
from cassandra.query import Statement
|
||||||
|
|
||||||
|
from six.moves import xrange
|
||||||
|
|
||||||
|
|
||||||
class TestLoadBalancingPolicy(unittest.TestCase):
|
class TestLoadBalancingPolicy(unittest.TestCase):
|
||||||
def test_non_implemented(self):
|
def test_non_implemented(self):
|
||||||
@@ -137,13 +140,23 @@ class TestRoundRobinPolicy(unittest.TestCase):
|
|||||||
|
|
||||||
# make the GIL switch after every instruction, maximizing
|
# make the GIL switch after every instruction, maximizing
|
||||||
# the chace of race conditions
|
# the chace of race conditions
|
||||||
original_interval = sys.getcheckinterval()
|
if six.PY2:
|
||||||
|
original_interval = sys.getcheckinterval()
|
||||||
|
else:
|
||||||
|
original_interval = sys.getswitchinterval()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sys.setcheckinterval(0)
|
if six.PY2:
|
||||||
|
sys.setcheckinterval(0)
|
||||||
|
else:
|
||||||
|
sys.setswitchinterval(0.0001)
|
||||||
map(lambda t: t.start(), threads)
|
map(lambda t: t.start(), threads)
|
||||||
map(lambda t: t.join(), threads)
|
map(lambda t: t.join(), threads)
|
||||||
finally:
|
finally:
|
||||||
sys.setcheckinterval(original_interval)
|
if six.PY2:
|
||||||
|
sys.setcheckinterval(original_interval)
|
||||||
|
else:
|
||||||
|
sys.setswitchinterval(original_interval)
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
self.fail("Saw errors: %s" % (errors,))
|
self.fail("Saw errors: %s" % (errors,))
|
||||||
@@ -334,14 +347,14 @@ class TokenAwarePolicyTest(unittest.TestCase):
|
|||||||
|
|
||||||
replicas = get_replicas(None, struct.pack('>i', i))
|
replicas = get_replicas(None, struct.pack('>i', i))
|
||||||
other = set(h for h in hosts if h not in replicas)
|
other = set(h for h in hosts if h not in replicas)
|
||||||
self.assertEquals(replicas, qplan[:2])
|
self.assertEqual(replicas, qplan[:2])
|
||||||
self.assertEquals(other, set(qplan[2:]))
|
self.assertEqual(other, set(qplan[2:]))
|
||||||
|
|
||||||
# Should use the secondary policy
|
# Should use the secondary policy
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
qplan = list(policy.make_query_plan())
|
qplan = list(policy.make_query_plan())
|
||||||
|
|
||||||
self.assertEquals(set(qplan), set(hosts))
|
self.assertEqual(set(qplan), set(hosts))
|
||||||
|
|
||||||
def test_wrap_dc_aware(self):
|
def test_wrap_dc_aware(self):
|
||||||
cluster = Mock(spec=Cluster)
|
cluster = Mock(spec=Cluster)
|
||||||
@@ -374,16 +387,16 @@ class TokenAwarePolicyTest(unittest.TestCase):
|
|||||||
|
|
||||||
# first should be the only local replica
|
# first should be the only local replica
|
||||||
self.assertIn(qplan[0], replicas)
|
self.assertIn(qplan[0], replicas)
|
||||||
self.assertEquals(qplan[0].datacenter, "dc1")
|
self.assertEqual(qplan[0].datacenter, "dc1")
|
||||||
|
|
||||||
# then the local non-replica
|
# then the local non-replica
|
||||||
self.assertNotIn(qplan[1], replicas)
|
self.assertNotIn(qplan[1], replicas)
|
||||||
self.assertEquals(qplan[1].datacenter, "dc1")
|
self.assertEqual(qplan[1].datacenter, "dc1")
|
||||||
|
|
||||||
# then one of the remotes (used_hosts_per_remote_dc is 1, so we
|
# then one of the remotes (used_hosts_per_remote_dc is 1, so we
|
||||||
# shouldn't see two remotes)
|
# shouldn't see two remotes)
|
||||||
self.assertEquals(qplan[2].datacenter, "dc2")
|
self.assertEqual(qplan[2].datacenter, "dc2")
|
||||||
self.assertEquals(3, len(qplan))
|
self.assertEqual(3, len(qplan))
|
||||||
|
|
||||||
class FakeCluster:
|
class FakeCluster:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@@ -346,7 +346,7 @@ class ResponseFutureTests(unittest.TestCase):
|
|||||||
rf.send_request()
|
rf.send_request()
|
||||||
|
|
||||||
rf.add_callbacks(
|
rf.add_callbacks(
|
||||||
callback=self.assertEquals, callback_args=([{'col': 'val'}],),
|
callback=self.assertEqual, callback_args=([{'col': 'val'}],),
|
||||||
errback=self.assertIsInstance, errback_args=(Exception,))
|
errback=self.assertIsInstance, errback_args=(Exception,))
|
||||||
|
|
||||||
result = Mock(spec=UnavailableErrorMessage, info={})
|
result = Mock(spec=UnavailableErrorMessage, info={})
|
||||||
@@ -358,7 +358,7 @@ class ResponseFutureTests(unittest.TestCase):
|
|||||||
rf.send_request()
|
rf.send_request()
|
||||||
|
|
||||||
rf.add_callbacks(
|
rf.add_callbacks(
|
||||||
callback=self.assertEquals, callback_args=([{'col': 'val'}],),
|
callback=self.assertEqual, callback_args=([{'col': 'val'}],),
|
||||||
errback=self.assertIsInstance, errback_args=(Exception,))
|
errback=self.assertIsInstance, errback_args=(Exception,))
|
||||||
|
|
||||||
rf._set_result(self.make_mock_response([{'col': 'val'}]))
|
rf._set_result(self.make_mock_response([{'col': 'val'}]))
|
||||||
@@ -380,9 +380,9 @@ class ResponseFutureTests(unittest.TestCase):
|
|||||||
|
|
||||||
session.submit.assert_called_once()
|
session.submit.assert_called_once()
|
||||||
args, kwargs = session.submit.call_args
|
args, kwargs = session.submit.call_args
|
||||||
self.assertEquals(rf._reprepare, args[-2])
|
self.assertEqual(rf._reprepare, args[-2])
|
||||||
self.assertIsInstance(args[-1], PrepareMessage)
|
self.assertIsInstance(args[-1], PrepareMessage)
|
||||||
self.assertEquals(args[-1].query, "SELECT * FROM foobar")
|
self.assertEqual(args[-1].query, "SELECT * FROM foobar")
|
||||||
|
|
||||||
def test_prepared_query_not_found_bad_keyspace(self):
|
def test_prepared_query_not_found_bad_keyspace(self):
|
||||||
session = self.make_session()
|
session = self.make_session()
|
||||||
|
|||||||
@@ -163,18 +163,17 @@ class TypeTests(unittest.TestCase):
|
|||||||
'7a6970:org.apache.cassandra.db.marshal.UTF8Type',
|
'7a6970:org.apache.cassandra.db.marshal.UTF8Type',
|
||||||
')')))
|
')')))
|
||||||
|
|
||||||
self.assertEquals(FooType, ctype.__class__)
|
self.assertEqual(FooType, ctype.__class__)
|
||||||
|
|
||||||
self.assertEquals(UTF8Type, ctype.subtypes[0])
|
self.assertEqual(UTF8Type, ctype.subtypes[0])
|
||||||
|
|
||||||
# middle subtype should be a BarType instance with its own subtypes and names
|
# middle subtype should be a BarType instance with its own subtypes and names
|
||||||
self.assertIsInstance(ctype.subtypes[1], BarType)
|
self.assertIsInstance(ctype.subtypes[1], BarType)
|
||||||
self.assertEquals([UTF8Type], ctype.subtypes[1].subtypes)
|
self.assertEqual([UTF8Type], ctype.subtypes[1].subtypes)
|
||||||
self.assertEquals(["address"], ctype.subtypes[1].names)
|
self.assertEqual([b"address"], ctype.subtypes[1].names)
|
||||||
|
|
||||||
self.assertEquals(UTF8Type, ctype.subtypes[2])
|
self.assertEqual(UTF8Type, ctype.subtypes[2])
|
||||||
|
self.assertEqual([b'city', None, b'zip'], ctype.names)
|
||||||
self.assertEquals(['city', None, 'zip'], ctype.names)
|
|
||||||
|
|
||||||
def test_empty_value(self):
|
def test_empty_value(self):
|
||||||
self.assertEqual(str(EmptyValue()), 'EMPTY')
|
self.assertEqual(str(EmptyValue()), 'EMPTY')
|
||||||
|
|||||||
10
tox.ini
10
tox.ini
@@ -1,5 +1,5 @@
|
|||||||
[tox]
|
[tox]
|
||||||
envlist = py26,py27,pypy
|
envlist = py26,py27,pypy,py33
|
||||||
|
|
||||||
[testenv]
|
[testenv]
|
||||||
deps = nose
|
deps = nose
|
||||||
@@ -8,5 +8,13 @@ deps = nose
|
|||||||
unittest2
|
unittest2
|
||||||
pip
|
pip
|
||||||
PyYAML
|
PyYAML
|
||||||
|
six
|
||||||
commands = {envpython} setup.py build_ext --inplace
|
commands = {envpython} setup.py build_ext --inplace
|
||||||
nosetests --verbosity=2 tests/unit/
|
nosetests --verbosity=2 tests/unit/
|
||||||
|
|
||||||
|
[testenv:py33]
|
||||||
|
deps = nose
|
||||||
|
mock
|
||||||
|
pip
|
||||||
|
PyYAML
|
||||||
|
six
|
||||||
|
|||||||
Reference in New Issue
Block a user