Start fixing integration tests for python 3

This commit is contained in:
Tyler Hobbs
2014-04-03 19:10:49 -05:00
parent 03b8e08653
commit 52eb86d2de
7 changed files with 35 additions and 30 deletions

View File

@@ -1,6 +1,7 @@
import sys import sys
from itertools import count, cycle from itertools import count, cycle
from six.moves import xrange
from threading import Event from threading import Event
@@ -79,7 +80,7 @@ def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs
parameters = [(x,) for x in range(1000)] parameters = [(x,) for x in range(1000)]
execute_concurrent_with_args(session, statement, parameters) execute_concurrent_with_args(session, statement, parameters)
""" """
return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs) return execute_concurrent(session, list(zip(cycle((statement,)), parameters)), *args, **kwargs)
_sentinel = object() _sentinel = object()
@@ -92,12 +93,12 @@ def _handle_error(error, result_index, event, session, statements, results, num_
return return
else: else:
results[result_index] = (False, error) results[result_index] = (False, error)
if num_finished.next() >= to_execute: if next(num_finished) >= to_execute:
event.set() event.set()
return return
try: try:
(next_index, (statement, params)) = statements.next() (next_index, (statement, params)) = next(statements)
except StopIteration: except StopIteration:
return return
@@ -113,7 +114,7 @@ def _handle_error(error, result_index, event, session, statements, results, num_
return return
else: else:
results[next_index] = (False, exc) results[next_index] = (False, exc)
if num_finished.next() >= to_execute: if next(num_finished) >= to_execute:
event.set() event.set()
return return
@@ -121,13 +122,13 @@ def _handle_error(error, result_index, event, session, statements, results, num_
def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error): def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
if result is not _sentinel: if result is not _sentinel:
results[result_index] = (True, result) results[result_index] = (True, result)
finished = num_finished.next() finished = next(num_finished)
if finished >= to_execute: if finished >= to_execute:
event.set() event.set()
return return
try: try:
(next_index, (statement, params)) = statements.next() (next_index, (statement, params)) = next(statements)
except StopIteration: except StopIteration:
return return
@@ -143,6 +144,6 @@ def _execute_next(result, result_index, event, session, statements, results, num
return return
else: else:
results[next_index] = (False, exc) results[next_index] = (False, exc)
if num_finished.next() >= to_execute: if next(num_finished) >= to_execute:
event.set() event.set()
return return

View File

@@ -57,7 +57,7 @@ def trim_if_startswith(s, prefix):
def unix_time_from_uuid1(u): def unix_time_from_uuid1(u):
return (u.get_time() - 0x01B21DD213814000) / 10000000.0 return (u.time - 0x01B21DD213814000) / 10000000.0
_casstypes = {} _casstypes = {}
@@ -318,7 +318,7 @@ if six.PY3:
(_UnrecognizedType,), (_UnrecognizedType,),
{'typename': "'%s'" % casstypename}) {'typename': "'%s'" % casstypename})
else: else:
def mkUnrecognizedType(casstypename): def mkUnrecognizedType(casstypename): # noqa
return CassandraTypeType(casstypename.encode('utf8'), return CassandraTypeType(casstypename.encode('utf8'),
(_UnrecognizedType,), (_UnrecognizedType,),
{'typename': "'%s'" % casstypename}) {'typename': "'%s'" % casstypename})

View File

@@ -5,7 +5,9 @@ except ImportError:
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
import os import os
from six import print_
from threading import Event from threading import Event
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
@@ -74,7 +76,7 @@ def get_node(node_id):
def setup_package(): def setup_package():
print 'Using Cassandra version: %s' % CASSANDRA_VERSION print_('Using Cassandra version: %s' % CASSANDRA_VERSION)
try: try:
try: try:
cluster = CCMCluster.load(path, CLUSTER_NAME) cluster = CCMCluster.load(path, CLUSTER_NAME)

View File

@@ -29,7 +29,7 @@ class ClusterTests(unittest.TestCase):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", )) statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(num_statements)] parameters = [(i, i) for i in range(num_statements)]
results = execute_concurrent(self.session, zip(statements, parameters)) results = execute_concurrent(self.session, list(zip(statements, parameters)))
self.assertEqual(num_statements, len(results)) self.assertEqual(num_statements, len(results))
self.assertEqual([(True, None)] * num_statements, results) self.assertEqual([(True, None)] * num_statements, results)
@@ -37,7 +37,7 @@ class ClusterTests(unittest.TestCase):
statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", )) statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", ))
parameters = [(i, ) for i in range(num_statements)] parameters = [(i, ) for i in range(num_statements)]
results = execute_concurrent(self.session, zip(statements, parameters)) results = execute_concurrent(self.session, list(zip(statements, parameters)))
self.assertEqual(num_statements, len(results)) self.assertEqual(num_statements, len(results))
self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results) self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)
@@ -67,7 +67,7 @@ class ClusterTests(unittest.TestCase):
self.assertRaises( self.assertRaises(
InvalidRequest, InvalidRequest,
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True) execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)
def test_first_failure_client_side(self): def test_first_failure_client_side(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", )) statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
@@ -78,7 +78,7 @@ class ClusterTests(unittest.TestCase):
self.assertRaises( self.assertRaises(
TypeError, TypeError,
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True) execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)
def test_no_raise_on_first_failure(self): def test_no_raise_on_first_failure(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", )) statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
@@ -87,7 +87,7 @@ class ClusterTests(unittest.TestCase):
# we'll get an error back from the server # we'll get an error back from the server
parameters[57] = ('efefef', 'awefawefawef') parameters[57] = ('efefef', 'awefawefawef')
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False) results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
for i, (success, result) in enumerate(results): for i, (success, result) in enumerate(results):
if i == 57: if i == 57:
self.assertFalse(success) self.assertFalse(success)
@@ -101,9 +101,9 @@ class ClusterTests(unittest.TestCase):
parameters = [(i, i) for i in range(100)] parameters = [(i, i) for i in range(100)]
# the driver will raise an error when binding the params # the driver will raise an error when binding the params
parameters[57] = i parameters[57] = 1
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False) results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
for i, (success, result) in enumerate(results): for i, (success, result) in enumerate(results):
if i == 57: if i == 57:
self.assertFalse(success) self.assertFalse(success)

View File

@@ -10,7 +10,7 @@ from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.policies import HostDistance from cassandra.policies import HostDistance
from tests.integration import get_server_versions, PROTOCOL_VERSION from tests.integration import PROTOCOL_VERSION
class QueryTest(unittest.TestCase): class QueryTest(unittest.TestCase):
@@ -29,7 +29,7 @@ class QueryTest(unittest.TestCase):
self.assertIsInstance(bound, BoundStatement) self.assertIsInstance(bound, BoundStatement)
self.assertEqual(2, len(bound.values)) self.assertEqual(2, len(bound.values))
session.execute(bound) session.execute(bound)
self.assertEqual(bound.routing_key, '\x00\x00\x00\x01') self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
def test_value_sequence(self): def test_value_sequence(self):
""" """
@@ -88,7 +88,7 @@ class PreparedStatementTests(unittest.TestCase):
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind((1, None)) bound = prepared.bind((1, None))
self.assertEqual(bound.routing_key, '\x00\x00\x00\x01') self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
def test_empty_routing_key_indexes(self): def test_empty_routing_key_indexes(self):
""" """
@@ -144,7 +144,7 @@ class PreparedStatementTests(unittest.TestCase):
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind((1, 2)) bound = prepared.bind((1, 2))
self.assertEqual(bound.routing_key, '\x04\x00\x00\x00\x04\x00\x00\x00') self.assertEqual(bound.routing_key, b'\x04\x00\x00\x00\x04\x00\x00\x00')
def test_bound_keyspace(self): def test_bound_keyspace(self):
""" """

View File

@@ -9,6 +9,7 @@ except ImportError:
import unittest # noqa import unittest # noqa
from itertools import cycle, count from itertools import cycle, count
from six.moves import range
from threading import Event from threading import Event
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
@@ -33,7 +34,7 @@ class QueryPagingTests(unittest.TestCase):
def test_paging(self): def test_paging(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)]) [(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params) execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test") prepared = self.session.prepare("SELECT * FROM test3rf.test")
@@ -49,7 +50,7 @@ class QueryPagingTests(unittest.TestCase):
def test_async_paging(self): def test_async_paging(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)]) [(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params) execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test") prepared = self.session.prepare("SELECT * FROM test3rf.test")
@@ -65,7 +66,7 @@ class QueryPagingTests(unittest.TestCase):
def test_paging_callbacks(self): def test_paging_callbacks(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)]) [(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params) execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test") prepared = self.session.prepare("SELECT * FROM test3rf.test")
@@ -78,7 +79,7 @@ class QueryPagingTests(unittest.TestCase):
def handle_page(rows, future, counter): def handle_page(rows, future, counter):
for row in rows: for row in rows:
counter.next() next(counter)
if future.has_more_pages: if future.has_more_pages:
future.start_fetching_next_page() future.start_fetching_next_page()
@@ -91,7 +92,7 @@ class QueryPagingTests(unittest.TestCase):
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error) future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait() event.wait()
self.assertEquals(counter.next(), 100) self.assertEquals(next(counter), 100)
# simple statement # simple statement
future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test")) future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"))
@@ -100,7 +101,7 @@ class QueryPagingTests(unittest.TestCase):
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error) future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait() event.wait()
self.assertEquals(counter.next(), 100) self.assertEquals(next(counter), 100)
# prepared statement # prepared statement
future = self.session.execute_async(prepared) future = self.session.execute_async(prepared)
@@ -109,4 +110,4 @@ class QueryPagingTests(unittest.TestCase):
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error) future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait() event.wait()
self.assertEquals(counter.next(), 100) self.assertEquals(next(counter), 100)

View File

@@ -3,6 +3,7 @@ try:
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
import binascii
from decimal import Decimal from decimal import Decimal
from datetime import datetime from datetime import datetime
from uuid import uuid1, uuid4 from uuid import uuid1, uuid4
@@ -44,8 +45,8 @@ class TypeTests(unittest.TestCase):
""") """)
params = [ params = [
'key1', b'key1',
'blobyblob'.encode('hex') binascii.hexlify(b'blobyblob')
] ]
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s)' query = 'INSERT INTO mytable (a, b) VALUES (%s, %s)'