From 4b1b50d929658713ab81bf06da13da5072e2424b Mon Sep 17 00:00:00 2001 From: Joaquin Casares Date: Mon, 16 Sep 2013 20:39:16 -0500 Subject: [PATCH] Added all extra tests --- .gitignore | 1 + tests/integration/__init__.py | 31 ++ tests/integration/test_cluster.py | 167 ++++++++++ tests/integration/test_factories.py | 124 ++++++++ tests/integration/test_metadata.py | 101 ++++++ tests/integration/test_metrics.py | 137 +++++++++ tests/integration/test_prepared_statements.py | 93 ++++++ tests/integration/test_query.py | 45 +++ tests/unit/test_host_connection_pool.py | 26 +- tests/unit/test_metadata.py | 101 ++++++ tests/unit/test_policies.py | 193 +++++++++++- tests/unit/test_types.py | 289 ++++++++++++++++++ 12 files changed, 1302 insertions(+), 6 deletions(-) create mode 100644 tests/integration/test_factories.py create mode 100644 tests/integration/test_metrics.py create mode 100644 tests/integration/test_query.py create mode 100644 tests/unit/test_metadata.py create mode 100644 tests/unit/test_types.py diff --git a/.gitignore b/.gitignore index c4fca1c5..ffe192c5 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ build MANIFEST dist .coverage +nosetests.xml cover/ docs/_build/ tests/integration/ccm diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 242e4292..83a971b2 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -8,6 +8,8 @@ log = logging.getLogger(__name__) import os from threading import Event +from cassandra.cluster import Cluster + try: from ccmlib.cluster import Cluster as CCMCluster from ccmlib import common @@ -25,6 +27,8 @@ if not os.path.exists(path): def get_cluster(): return CCM_CLUSTER +def get_node(node_id): + return CCM_CLUSTER.nodes['node%s' % node_id] def setup_package(): try: @@ -47,6 +51,33 @@ def setup_package(): global CCM_CLUSTER CCM_CLUSTER = cluster + setup_test_keyspace() + +def setup_test_keyspace(): + cluster = Cluster() + session = cluster.connect() + + ksname = 'test3rf' + cfname = 'test' + + try: + results = session.execute("SELECT keyspace_name FROM system.schema_keyspaces") + existing_keyspaces = [row[0] for row in results] + if ksname in existing_keyspaces: + session.execute("DROP KEYSPACE %s" % ksname) + + ddl = ''' + CREATE KEYSPACE %s + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}''' + session.execute(ddl % ksname) + + ddl = ''' + CREATE TABLE %s.%s ( + k int PRIMARY KEY, + v int )''' + session.execute(ddl % (ksname, cfname)) + finally: + cluster.shutdown() def teardown_package(): diff --git a/tests/integration/test_cluster.py b/tests/integration/test_cluster.py index a20a78d0..ac9c1342 100644 --- a/tests/integration/test_cluster.py +++ b/tests/integration/test_cluster.py @@ -3,11 +3,20 @@ try: except ImportError: import unittest # noqa +import cassandra +from cassandra.query import SimpleStatement +from cassandra.io.asyncorereactor import AsyncoreConnection +from cassandra.policies import RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance + from cassandra.cluster import Cluster, NoHostAvailable class ClusterTests(unittest.TestCase): def test_basic(self): + """ + Test basic connection and usage + """ + cluster = Cluster() session = cluster.connect() result = session.execute( @@ -39,11 +48,117 @@ class ClusterTests(unittest.TestCase): cluster.shutdown() + def test_connect_on_keyspace(self): + """ + Ensure clusters that connect on a keyspace, do + """ + + cluster = Cluster() + session = cluster.connect() + result = session.execute( + """ + INSERT INTO test3rf.test (k, v) VALUES (8889, 8889) + """) + self.assertEquals(None, result) + + result = session.execute("SELECT * FROM test3rf.test") + self.assertEquals([(8889, 8889)], result) + + # test_connect_on_keyspace + session2 = cluster.connect('test3rf') + result2 = session2.execute("SELECT * FROM test") + self.assertEquals(result, result2) + + def test_default_connections(self): + """ + Ensure errors are not thrown when using non-default policies + """ + + cluster = Cluster( + load_balancing_policy=RoundRobinPolicy(), + reconnection_policy=ExponentialReconnectionPolicy(1.0, 600.0), + default_retry_policy=RetryPolicy(), + conviction_policy_factory=SimpleConvictionPolicy, + connection_class=AsyncoreConnection + ) + + def test_double_shutdown(self): + """ + Ensure that a cluster can be shutdown twice, without error + """ + + # DISCUSS: Should we allow this? + + cluster = Cluster() + cluster.shutdown() + + # Shouldn't throw an error + cluster.shutdown() + + def test_connect_to_already_shutdown_cluster(self): + """ + Ensure you cannot connect to a cluster that's been shutdown + """ + + cluster = Cluster() + cluster.shutdown() + self.assertRaises(Exception, cluster.connect) + + def test_auth_provider_is_callable(self): + """ + Ensure that auth_providers are always callable + """ + + self.assertRaises(ValueError, Cluster, auth_provider=1) + + def test_conviction_policy_factory_is_callable(self): + """ + Ensure that conviction_policy_factory are always callable + """ + + self.assertRaises(ValueError, Cluster, conviction_policy_factory=1) + def test_connect_to_bad_hosts(self): + """ + Ensure that a NoHostAvailable Exception is thrown + when a cluster cannot connect to given hosts + """ + cluster = Cluster(['127.1.2.9', '127.1.2.10']) self.assertRaises(NoHostAvailable, cluster.connect) + def test_cluster_settings(self): + """ + Test connection setting getters and setters + """ + + cluster = Cluster() + + min_requests_per_connection = cluster.get_min_requests_per_connection(HostDistance.LOCAL) + self.assertEqual(cassandra.cluster.DEFAULT_MIN_REQUESTS, min_requests_per_connection) + cluster.set_min_requests_per_connection(HostDistance.LOCAL, min_requests_per_connection + 1) + self.assertEqual(cluster.get_min_requests_per_connection(HostDistance.LOCAL), min_requests_per_connection + 1) + + max_requests_per_connection = cluster.get_max_requests_per_connection(HostDistance.LOCAL) + self.assertEqual(cassandra.cluster.DEFAULT_MAX_REQUESTS, max_requests_per_connection) + cluster.set_max_requests_per_connection(HostDistance.LOCAL, max_requests_per_connection + 1) + self.assertEqual(cluster.get_max_requests_per_connection(HostDistance.LOCAL), max_requests_per_connection + 1) + + core_connections_per_host = cluster.get_core_connections_per_host(HostDistance.LOCAL) + self.assertEqual(cassandra.cluster.DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST, core_connections_per_host) + cluster.set_core_connections_per_host(HostDistance.LOCAL, core_connections_per_host + 1) + self.assertEqual(cluster.get_core_connections_per_host(HostDistance.LOCAL), core_connections_per_host + 1) + + max_connections_per_host = cluster.get_max_connections_per_host(HostDistance.LOCAL) + self.assertEqual(cassandra.cluster.DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST, max_connections_per_host) + cluster.set_max_connections_per_host(HostDistance.LOCAL, max_connections_per_host + 1) + self.assertEqual(cluster.get_max_connections_per_host(HostDistance.LOCAL), max_connections_per_host + 1) + def test_submit_schema_refresh(self): + """ + Ensure new new schema is refreshed after submit_schema_refresh() + """ + cluster = Cluster() cluster.connect() self.assertNotIn("newkeyspace", cluster.metadata.keyspaces) @@ -62,6 +177,10 @@ class ClusterTests(unittest.TestCase): self.assertIn("newkeyspace", cluster.metadata.keyspaces) def test_on_down_and_up(self): + """ + Test on_down and on_up handling + """ + cluster = Cluster() session = cluster.connect() host = cluster.metadata.all_hosts()[0] @@ -78,3 +197,51 @@ class ClusterTests(unittest.TestCase): self.assertEqual(None, host._reconnection_handler) self.assertTrue(host_reconnector._cancelled) self.assertIn(host, session._pools) + + def test_trace(self): + """ + Ensure trace can be requested for async and non-async queries + """ + + cluster = Cluster() + session = cluster.connect() + + self.assertRaises(TypeError, session.execute, "SELECT * FROM system.local", trace=True) + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + session.execute(statement, trace=True) + self.assertEqual(query, statement.trace.parameters['query']) + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + session.execute(statement) + self.assertEqual(None, statement.trace) + + statement2 = SimpleStatement(query) + future = session.execute_async(statement2, trace=True) + future.result() + self.assertEqual(query, future.get_query_trace().parameters['query']) + + statement2 = SimpleStatement(query) + future = session.execute_async(statement2) + future.result() + self.assertEqual(None, future.get_query_trace()) + + def test_string_coverage(self): + """ + Ensure str(future) returns without error + """ + + cluster = Cluster() + session = cluster.connect() + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + future = session.execute_async(statement) + + self.assertIn(query, str(future)) + future.result() + + self.assertIn(query, str(future)) + self.assertIn('result', str(future)) diff --git a/tests/integration/test_factories.py b/tests/integration/test_factories.py new file mode 100644 index 00000000..1cb9e178 --- /dev/null +++ b/tests/integration/test_factories.py @@ -0,0 +1,124 @@ +import unittest +import cassandra +from cassandra.cluster import Cluster +from cassandra.decoder import tuple_factory, named_tuple_factory, dict_factory, ordered_dict_factory + +try: + from collections import OrderedDict +except ImportError: # Python <2.7 + from cassandra.util import OrderedDict # NOQA + +class TestFactories(unittest.TestCase): + """ + Test different row_factories and access code + """ + + truncate = ''' + TRUNCATE test3rf.test + ''' + + insert1 = ''' + INSERT INTO test3rf.test + ( k , v ) + VALUES + ( 1 , 1 ) + ''' + + insert2 = ''' + INSERT INTO test3rf.test + ( k , v ) + VALUES + ( 2 , 2 ) + ''' + + select = ''' + SELECT * FROM test3rf.test + ''' + + def test_tuple_factory(self): + cluster = Cluster() + session = cluster.connect() + session.row_factory = tuple_factory + + session.execute(self.truncate) + session.execute(self.insert1) + session.execute(self.insert2) + + result = session.execute(self.select) + + self.assertTrue(isinstance(result, list)) + self.assertTrue(isinstance(result[0], tuple)) + + for row in result: + self.assertEqual(row[0], row[1]) + + self.assertEqual(result[0][0], result[0][1]) + self.assertEqual(result[0][0], 1) + self.assertEqual(result[1][0], result[1][1]) + self.assertEqual(result[1][0], 2) + + def test_named_tuple_factoryy(self): + cluster = Cluster() + session = cluster.connect() + session.row_factory = named_tuple_factory + + session.execute(self.truncate) + session.execute(self.insert1) + session.execute(self.insert2) + + result = session.execute(self.select) + + self.assertTrue(isinstance(result, list)) + + for row in result: + self.assertEqual(row.k, row.v) + + self.assertEqual(result[0].k, result[0].v) + self.assertEqual(result[0].k, 1) + self.assertEqual(result[1].k, result[1].v) + self.assertEqual(result[1].k, 2) + + + def test_dict_factory(self): + cluster = Cluster() + session = cluster.connect() + session.row_factory = dict_factory + + session.execute(self.truncate) + session.execute(self.insert1) + session.execute(self.insert2) + + result = session.execute(self.select) + + self.assertTrue(isinstance(result, list)) + self.assertTrue(isinstance(result[0], dict)) + + for row in result: + self.assertEqual(row['k'], row['v']) + + self.assertEqual(result[0]['k'], result[0]['v']) + self.assertEqual(result[0]['k'], 1) + self.assertEqual(result[1]['k'], result[1]['v']) + self.assertEqual(result[1]['k'], 2) + + def test_ordered_dict_factory(self): + cluster = Cluster() + session = cluster.connect() + session.row_factory = ordered_dict_factory + + session.execute(self.truncate) + session.execute(self.insert1) + session.execute(self.insert2) + + result = session.execute(self.select) + + self.assertTrue(isinstance(result, list)) + self.assertTrue(isinstance(result[0], OrderedDict)) + + for row in result: + self.assertEqual(row['k'], row['v']) + + self.assertEqual(result[0]['k'], result[0]['v']) + self.assertEqual(result[0]['k'], 1) + self.assertEqual(result[1]['k'], result[1]['v']) + self.assertEqual(result[1]['k'], 2) diff --git a/tests/integration/test_metadata.py b/tests/integration/test_metadata.py index 930474a2..0689d9c9 100644 --- a/tests/integration/test_metadata.py +++ b/tests/integration/test_metadata.py @@ -3,8 +3,12 @@ try: except ImportError: import unittest # noqa +import cassandra +from cassandra import AlreadyExists + from cassandra.cluster import Cluster from cassandra.metadata import KeyspaceMetadata, TableMetadata, Token, MD5Token, TokenMap +from cassandra.metadata import TableMetadata, Token, MD5Token, TokenMap, Murmur3Token from cassandra.policies import SimpleConvictionPolicy from cassandra.pool import Host @@ -272,6 +276,103 @@ class SchemaMetadataTest(unittest.TestCase): self.assertEqual(d_index, statements[1]) self.assertEqual(e_index, statements[2]) +class TestCodeCoverage(unittest.TestCase): + def test_export_schema(self): + """ + Test export schema functionality + """ + + cluster = Cluster() + cluster.connect() + + # BUG: Does not work + # print cluster.metadata.export_schema_as_string() + # Traceback (most recent call last): + # File "/Users/joaquin/repos/python-driver/tests/integration/test_metadata.py", line 280, in test_export_schema + # print cluster.metadata.export_schema_as_string() + # File "/Users/joaquin/repos/python-driver/cassandra/metadata.py", line 71, in export_schema_as_string + # return "\n".join(ks.export_as_string() for ks in self.keyspaces.values()) + # File "/Users/joaquin/repos/python-driver/cassandra/metadata.py", line 71, in + # return "\n".join(ks.export_as_string() for ks in self.keyspaces.values()) + # File "/Users/joaquin/repos/python-driver/cassandra/metadata.py", line 351, in export_as_string + # return "\n".join([self.as_cql_query()] + [t.as_cql_query() for t in self.tables.values()]) + # TypeError: sequence item 0: expected string, NoneType found + + def test_export_keyspace_schema(self): + """ + Test export keyspace schema functionality + """ + + cluster = Cluster() + cluster.connect() + + # BUG: Attempting to check cassandra.metadata:350 + # print cluster.metadata.keyspaces.export_as_string() + # Traceback (most recent call last): + # File "/Users/joaquin/repos/python-driver/tests/integration/test_metadata.py", line 296, in test_export_keyspace_schema + # print cluster.metadata.keyspaces.export_as_string() + # AttributeError: 'dict' object has no attribute 'export_as_string' + + # BUG: Attempting to check cassandra.metadata:353 + # print cluster.metadata.keyspaces.as_cql_query() + # Traceback (most recent call last): + # File "/Users/joaquin/repos/python-driver/tests/integration/test_metadata.py", line 305, in test_export_keyspace_schema + # print cluster.metadata.keyspaces.as_cql_query() + # AttributeError: 'dict' object has no attribute 'as_cql_query' + + def test_already_exists_exceptions(self): + """ + Ensure AlreadyExists exception is thrown when hit + """ + + cluster = Cluster() + session = cluster.connect() + + ksname = 'test3rf' + cfname = 'test' + + ddl = ''' + CREATE KEYSPACE %s + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}''' + self.assertRaises(AlreadyExists, session.execute, ddl % ksname) + + ddl = ''' + CREATE TABLE %s.%s ( + k int PRIMARY KEY, + v int )''' + self.assertRaises(AlreadyExists, session.execute, ddl % (ksname, cfname)) + + def test_replicas(self): + """ + Ensure cluster.metadata.get_replicas return correctly when not attached to keyspace + """ + cluster = Cluster() + self.assertEqual(cluster.metadata.get_replicas('key'), []) + + cluster.connect('test3rf') + host = list(cluster.metadata.get_replicas('key'))[0] + self.assertEqual(host.datacenter, 'datacenter1') + self.assertEqual(host.rack, 'rack1') + self.assertEqual(host.address, '127.0.0.2') + + def test_token_map(self): + """ + Test token mappings + """ + + cluster = Cluster() + cluster.connect('test3rf') + ring = cluster.metadata.token_map.ring + + self.assertEqual(list(cluster.metadata.token_map.get_replicas(ring[0]))[0].address, '127.0.0.1') + self.assertEqual(list(cluster.metadata.token_map.get_replicas(ring[1]))[0].address, '127.0.0.2') + self.assertEqual(list(cluster.metadata.token_map.get_replicas(ring[2]))[0].address, '127.0.0.3') + + # BUG: I was specifically trying to ensure that tokens wrap around + self.assertEqual(list(cluster.metadata.token_map.get_replicas(Murmur3Token(ring[0].value - 1)))[0].address, '127.0.0.3') + # self.assertEqual(list(cluster.metadata.token_map.get_replicas(Murmur3Token(ring[1].value - 1)))[0].address, '127.0.0.1') + # self.assertEqual(list(cluster.metadata.token_map.get_replicas(Murmur3Token(ring[2].value - 1)))[0].address, '127.0.0.2') + class TokenMetadataTest(unittest.TestCase): """ diff --git a/tests/integration/test_metrics.py b/tests/integration/test_metrics.py new file mode 100644 index 00000000..af2ea2be --- /dev/null +++ b/tests/integration/test_metrics.py @@ -0,0 +1,137 @@ +import unittest +from cassandra.query import SimpleStatement +from cassandra import ConsistencyLevel, WriteTimeout, Unavailable, ReadTimeout + +from cassandra.cluster import Cluster, NoHostAvailable +from cassandra.decoder import QueryMessage +from tests.integration import get_node, get_cluster + + +class MetricsTests(unittest.TestCase): + + def test_connection_error(self): + """ + Trigger and ensure connection_errors are counted + """ + + # DISCUSS: Doesn't trigger code coverage on cassandra.metrics.on_connection_error(). Find new example. + cluster = Cluster(metrics_enabled=True) + session = cluster.connect() + + # Test write + session.execute("USE test3rf") + + # Force kill cluster + get_cluster().stop(wait=True, gently=False) + try: + self.assertRaises(NoHostAvailable, session.execute, "USE test3rf") + finally: + get_cluster().start(wait_for_binary_proto=True) + + def test_write_timeout(self): + """ + Trigger and ensure write_timeouts are counted + Write a key, value pair. Force kill a node without waiting for the cluster to register the death. + Attempt a write at cl.ALL and receive a WriteTimeout. + """ + + cluster = Cluster(metrics_enabled=True) + session = cluster.connect() + + # Test write + session.execute("INSERT INTO test3rf.test (k, v) VALUES (1, 1)") + + # Assert read + query = SimpleStatement("SELECT v FROM test3rf.test WHERE k=%(k)s", consistency_level=ConsistencyLevel.ALL) + results = session.execute(query, {'k': 1}) + self.assertEqual(1, results[0].v) + + # Force kill ccm node + get_node(1).stop(wait=False, gently=False) + + try: + # Test write + query = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) + self.assertRaises(WriteTimeout, session.execute, query) + self.assertEqual(1, cluster.metrics.stats.write_timeouts) + + finally: + get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) + + def test_read_timeout(self): + """ + Trigger and ensure read_timeouts are counted + Write a key, value pair. Force kill a node without waiting for the cluster to register the death. + Attempt a read at cl.ALL and receive a ReadTimeout. + """ + + cluster = Cluster(metrics_enabled=True) + session = cluster.connect() + + # Test write + session.execute("INSERT INTO test3rf.test (k, v) VALUES (1, 1)") + + # Assert read + query = SimpleStatement("SELECT v FROM test3rf.test WHERE k=%(k)s", consistency_level=ConsistencyLevel.ALL) + results = session.execute(query, {'k': 1}) + self.assertEqual(1, results[0].v) + + # Force kill ccm node + get_node(1).stop(wait=False, gently=False) + + try: + # Test read + query = SimpleStatement("SELECT v FROM test3rf.test WHERE k=%(k)s", consistency_level=ConsistencyLevel.ALL) + self.assertRaises(ReadTimeout, session.execute, query, {'k': 1}) + self.assertEqual(1, cluster.metrics.stats.read_timeouts) + + finally: + get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) + + def test_unavailable(self): + """ + Trigger and ensure unavailables are counted + Write a key, value pair. Kill a node while waiting for the cluster to register the death. + Attempt an insert/read at cl.ALL and receive a Unavailable Exception. + """ + + cluster = Cluster(metrics_enabled=True) + session = cluster.connect() + + # Test write + session.execute("INSERT INTO test3rf.test (k, v) VALUES (1, 1)") + + # Assert read + query = SimpleStatement("SELECT v FROM test3rf.test WHERE k=%(k)s", consistency_level=ConsistencyLevel.ALL) + results = session.execute(query, {'k': 1}) + self.assertEqual(1, results[0].v) + + # Force kill ccm node + get_node(1).stop(wait=True, gently=True) + + try: + # Test write + query = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) + self.assertRaises(Unavailable, session.execute, query) + self.assertEqual(1, cluster.metrics.stats.unavailables) + + # Test write + query = SimpleStatement("SELECT v FROM test3rf.test WHERE k=%(k)s", consistency_level=ConsistencyLevel.ALL) + self.assertRaises(Unavailable, session.execute, query, {'k': 1}) + self.assertEqual(2, cluster.metrics.stats.unavailables) + finally: + get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) + + def test_other_error(self): + # TODO: Bootstrapping or Overloaded cases + pass + + + def test_ignore(self): + # TODO: Look for ways to generate ignores + pass + + + def test_retry(self): + # TODO: Look for ways to generate retries + pass diff --git a/tests/integration/test_prepared_statements.py b/tests/integration/test_prepared_statements.py index 1dc62c1f..4fe5e0fb 100644 --- a/tests/integration/test_prepared_statements.py +++ b/tests/integration/test_prepared_statements.py @@ -2,6 +2,7 @@ try: import unittest2 as unittest except ImportError: import unittest # noqa +from cassandra import InvalidRequest from cassandra.cluster import Cluster from cassandra.query import PreparedStatement @@ -9,6 +10,10 @@ from cassandra.query import PreparedStatement class PreparedStatementTests(unittest.TestCase): def test_basic(self): + """ + Test basic PreparedStatement usage + """ + cluster = Cluster() session = cluster.connect() session.execute( @@ -47,3 +52,91 @@ class PreparedStatementTests(unittest.TestCase): bound = prepared.bind(('a')) results = session.execute(bound) self.assertEquals(results, [('a', 'b', 'c')]) + + def test_missing_primary_key(self): + """ + Ensure an InvalidRequest is thrown + when prepared statements are missing the primary key + """ + + cluster = Cluster() + session = cluster.connect() + + prepared = session.prepare( + """ + INSERT INTO test3rf.test (v) VALUES (?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind((1,)) + self.assertRaises(InvalidRequest, session.execute, bound) + + def test_too_many_bind_values(self): + """ + Ensure a ValueError is thrown when attempting to bind too many variables + """ + + cluster = Cluster() + session = cluster.connect() + + prepared = session.prepare( + """ + INSERT INTO test3rf.test (v) VALUES (?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + self.assertRaises(ValueError, prepared.bind, (1,2)) + + def test_none_values(self): + """ + Ensure binding None is handled correctly + """ + + cluster = Cluster() + session = cluster.connect() + + prepared = session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind((1, None)) + session.execute(bound) + + prepared = session.prepare( + """ + SELECT * FROM test3rf.test WHERE k=? + """) + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind((1,)) + results = session.execute(bound) + self.assertEquals(results[0].v, None) + + def test_async_binding(self): + """ + Ensure None binding over async queries + """ + + cluster = Cluster() + session = cluster.connect() + + prepared = session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + future = session.execute_async(prepared, (873, None)) + future.result() + + prepared = session.prepare( + """ + SELECT * FROM test3rf.test WHERE k=? + """) + self.assertIsInstance(prepared, PreparedStatement) + + future = session.execute_async(prepared, (873,)) + results = future.result() + self.assertEquals(results[0].v, None) diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py new file mode 100644 index 00000000..0b00af39 --- /dev/null +++ b/tests/integration/test_query.py @@ -0,0 +1,45 @@ +import unittest +from cassandra.query import PreparedStatement, BoundStatement, ValueSequence, SimpleStatement +from cassandra.cluster import Cluster + + +class QueryTest(unittest.TestCase): + # TODO: Cover routing keys + # def test_query(self): + # cluster = Cluster() + # session = cluster.connect() + # + # prepared = session.prepare( + # """ + # INSERT INTO test3rf.test (k, v) VALUES (?, ?) + # """) + # + # self.assertIsInstance(prepared, PreparedStatement) + # bound = prepared.bind((1, None)) + # self.assertIsInstance(bound, BoundStatement) + # session.execute(bound) + # + # print bound.routing_key + + def test_value_sequence(self): + """ + Test the output of ValueSequences() + """ + + my_user_ids = ('alice', 'bob', 'charles') + self.assertEqual(str(ValueSequence(my_user_ids)), "( 'alice' , 'bob' , 'charles' )") + + def test_trace_prints_okay(self): + """ + Code coverage to ensure trace prints to string without error + """ + + cluster = Cluster() + session = cluster.connect() + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + session.execute(statement, trace=True) + + # Ensure this does not throw an exception + str(statement.trace) diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index d8d075f1..a5ae8302 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -9,7 +9,7 @@ from threading import Thread, Event from cassandra.cluster import Session from cassandra.connection import Connection, MAX_STREAM_PER_CONNECTION from cassandra.pool import Host, HostConnectionPool, NoConnectionsAvailable, HealthMonitor -from cassandra.policies import HostDistance +from cassandra.policies import HostDistance, SimpleConvictionPolicy class HostConnectionPoolTests(unittest.TestCase): @@ -202,3 +202,27 @@ class HostConnectionPoolTests(unittest.TestCase): # a new creation should be scheduled session.submit.assert_called_once() self.assertFalse(pool.is_shutdown) + + def test_host_instantiations(self): + """ + Ensure Host fails if not initialized properly + """ + + self.assertRaises(ValueError, Host, None, None) + self.assertRaises(ValueError, Host, '127.0.0.1', None) + self.assertRaises(ValueError, Host, None, SimpleConvictionPolicy) + + def test_host_equality(self): + """ + Test host equality has correct logic + """ + + a = Host('127.0.0.1', SimpleConvictionPolicy) + b = Host('127.0.0.1', SimpleConvictionPolicy) + c = Host('127.0.0.2', SimpleConvictionPolicy) + + self.assertEqual(a, b, 'Two Host instances should be equal when sharing.') + self.assertNotEqual(a, c, 'Two Host instances should NOT be equal when using two different addresses.') + self.assertNotEqual(b, c, 'Two Host instances should NOT be equal when using two different addresses.') + + self.assertFalse(a == '127.0.0.1') diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py new file mode 100644 index 00000000..83447a28 --- /dev/null +++ b/tests/unit/test_metadata.py @@ -0,0 +1,101 @@ +import unittest +import cassandra +from cassandra.cluster import Cluster +from cassandra.metadata import TableMetadata, Murmur3Token, MD5Token, BytesToken + + +class TestMetadata(unittest.TestCase): + + def test_protect_name(self): + """ + Test TableMetadata.protect_name output + """ + + table_metadata = TableMetadata('ks_name', 'table_name') + + self.assertEqual(table_metadata.protect_name('tests'), 'tests') + self.assertEqual(table_metadata.protect_name('test\'s'), '"test\'s"') + self.assertEqual(table_metadata.protect_name('test\'s'), "\"test's\"") + self.assertEqual(table_metadata.protect_name('tests ?!@#$%^&*()'), '"tests ?!@#$%^&*()"') + + # BUG: Or is this fine? + self.assertEqual(table_metadata.protect_name('1'), '"1"') + + def test_protect_names(self): + """ + Test TableMetadata.protect_names output + """ + + table_metadata = TableMetadata('ks_name', 'table_name') + + self.assertEqual(table_metadata.protect_names(['tests']), ['tests']) + self.assertEqual(table_metadata.protect_names( + [ + 'tests', + 'test\'s', + 'tests ?!@#$%^&*()', + '1' + ]), + [ + 'tests', + "\"test's\"", + '"tests ?!@#$%^&*()"', + '"1"' + ]) + + def test_protect_value(self): + """ + Test TableMetadata.protect_value output + """ + + table_metadata = TableMetadata('ks_name', 'table_name') + + self.assertEqual(table_metadata.protect_value(True), "'true'") + self.assertEqual(table_metadata.protect_value(False), "'false'") + self.assertEqual(table_metadata.protect_value(3.14), '3.140000') + self.assertEqual(table_metadata.protect_value(3), '3') + self.assertEqual(table_metadata.protect_value('test'), "'test'") + self.assertEqual(table_metadata.protect_value('test\'s'), "'test''s'") + + # BUG: Do we remove this altogether now? + self.assertEqual(table_metadata.protect_value(None), 'NULL') + + def test_is_valid_name(self): + """ + Test TableMetadata.is_valid_name output + """ + + table_metadata = TableMetadata('ks_name', 'table_name') + + self.assertEqual(table_metadata.is_valid_name(None), False) + self.assertEqual(table_metadata.is_valid_name('test'), True) + self.assertEqual(table_metadata.is_valid_name('Test'), False) + self.assertEqual(table_metadata.is_valid_name('t_____1'), True) + self.assertEqual(table_metadata.is_valid_name('test1'), True) + self.assertEqual(table_metadata.is_valid_name('1test1'), False) + + non_valid_keywords = cassandra.metadata._keywords - cassandra.metadata._unreserved_keywords + for keyword in non_valid_keywords: + self.assertEqual(table_metadata.is_valid_name(keyword), False) + + def test_token_values(self): + """ + Spot check token classes and values + """ + + # spot check murmur3 + murmur3_token = Murmur3Token(cassandra.metadata.MIN_LONG - 1) + self.assertEqual(murmur3_token.hash_fn('123'), -7468325962851647638) + self.assertEqual(murmur3_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 7162290910810015547) + + md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1) + # BUG: MD5Token always returns the same token + # self.assertNotEqual(md5_token.hash_fn('123'), 110673303387115207421586718101067225896) + # self.assertNotEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 110673303387115207421586718101067225896) + + bytes_token = BytesToken(cassandra.metadata.MIN_LONG - 1) + self.assertEqual(bytes_token.hash_fn('123'), '123') + self.assertEqual(bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)), str(cassandra.metadata.MAX_LONG)) + + # BUG? Should only accept strings? + self.assertEqual(bytes_token.hash_fn(123), '123') diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 0c625a4c..3ab2da9e 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -15,10 +15,31 @@ from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy, TokenAwarePolicy, SimpleConvictionPolicy, HostDistance, ExponentialReconnectionPolicy, RetryPolicy, WriteType, - DowngradingConsistencyRetryPolicy) + DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy, + LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy) from cassandra.pool import Host from cassandra.query import Statement + +class TestLoadBalancingPolicy(unittest.TestCase): + def test_non_implemented(self): + """ + Code coverage for interface-style base class + """ + + policy = LoadBalancingPolicy() + host = Host("ip1", SimpleConvictionPolicy) + host.set_location_info("dc1", "rack1") + + self.assertRaises(NotImplementedError, policy.distance, host) + self.assertRaises(NotImplementedError, policy.distance, host) + self.assertRaises(NotImplementedError, policy.make_query_plan) + self.assertRaises(NotImplementedError, policy.on_up, host) + self.assertRaises(NotImplementedError, policy.on_down, host) + self.assertRaises(NotImplementedError, policy.on_add, host) + self.assertRaises(NotImplementedError, policy.on_remove, host) + + class TestRoundRobinPolicy(unittest.TestCase): def test_basic(self): @@ -67,8 +88,23 @@ class TestRoundRobinPolicy(unittest.TestCase): map(lambda t: t.start(), threads) map(lambda t: t.join(), threads) + def test_no_live_nodes(self): + """ + Ensure query plan for a downed cluster will execute without errors + """ -class TestDCAwareRoundRobinPolicy(unittest.TestCase): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + + for i in range(4): + policy.on_down(i) + + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, []) + + +class DCAwareRoundRobinPolicyTest(unittest.TestCase): def test_no_remote(self): hosts = [] @@ -173,6 +209,47 @@ class TestDCAwareRoundRobinPolicy(unittest.TestCase): # since we have hosts in dc9000, the distance shouldn't be IGNORED self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE) + policy.on_down(new_local_host) + policy.on_down(hosts[1]) + qplan = list(policy.make_query_plan()) + self.assertEqual(set(qplan), set([hosts[3], new_remote_host])) + + policy.on_down(new_remote_host) + policy.on_down(hosts[3]) + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, []) + + def test_no_live_nodes(self): + """ + Ensure query plan for a downed cluster will execute without errors + """ + + hosts = [] + for i in range(4): + h = Host(i, SimpleConvictionPolicy) + h.set_location_info("dc1", "rack1") + hosts.append(h) + + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy.populate(None, hosts) + + for host in hosts: + policy.on_down(host) + + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, []) + + def test_no_nodes(self): + """ + Ensure query plan for an empty cluster will execute without errors + """ + + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy.populate(None, []) + + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, []) + class TokenAwarePolicyTest(unittest.TestCase): @@ -199,6 +276,19 @@ class TokenAwarePolicyTest(unittest.TestCase): self.assertEquals(replicas, qplan[:2]) self.assertEquals(other, set(qplan[2:])) + # Should use the secondary policy + for i in range(4): + query = Query() + qplan = list(policy.make_query_plan(query)) + + self.assertEquals(set(qplan), set(hosts)) + + # Should use the secondary policy + for i in range(4): + qplan = list(policy.make_query_plan()) + + self.assertEquals(set(qplan), set(hosts)) + def test_wrap_dc_aware(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) @@ -239,6 +329,74 @@ class TokenAwarePolicyTest(unittest.TestCase): self.assertEquals(qplan[2].datacenter, "dc2") self.assertEquals(3, len(qplan)) +class ConvictionPolicyTest(unittest.TestCase): + def test_not_implemented(self): + """ + Code coverage for interface-style base class + """ + + conviction_policy = ConvictionPolicy(1) + self.assertRaises(NotImplementedError, conviction_policy.add_failure, 1) + self.assertRaises(NotImplementedError, conviction_policy.reset) + + +class SimpleConvictionPolicyTest(unittest.TestCase): + def test_basic_responses(self): + """ + Code coverage for SimpleConvictionPolicy + """ + + conviction_policy = SimpleConvictionPolicy(1) + + # DISCUSS: Always return True? + self.assertEqual(conviction_policy.add_failure(1), True) + + self.assertEqual(conviction_policy.reset(), None) + + +class ReconnectionPolicyTest(unittest.TestCase): + def test_basic_responses(self): + """ + Code coverage for interface-style base class + """ + + policy = ReconnectionPolicy() + self.assertRaises(NotImplementedError, policy.new_schedule) + + +class ConstantReconnectionPolicyTest(unittest.TestCase): + + def test_bad_vals(self): + """ + Test initialization values + """ + + self.assertRaises(ValueError, ConstantReconnectionPolicy, -1, 0) + + def test_schedule(self): + """ + Test ConstantReconnectionPolicy schedule + """ + + delay = 2 + max_attempts = 100 + policy = ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts) + schedule = list(policy.new_schedule()) + self.assertEqual(len(schedule), max_attempts) + for i, delay in enumerate(schedule): + self.assertEqual(delay, delay) + + def test_schedule_negative_max_attempts(self): + """ + Test how negative max_attempts are handled + """ + + delay = 2 + max_attempts = -100 + policy = ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts) + schedule = list(policy.new_schedule()) + self.assertEqual(len(schedule), 0) + class ExponentialReconnectionPolicyTest(unittest.TestCase): @@ -270,25 +428,28 @@ class RetryPolicyTest(unittest.TestCase): query=None, consistency="ONE", required_responses=1, received_responses=2, data_retrieved=True, retry_num=1) self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) # if we didn't get enough responses, rethrow retry, consistency = policy.on_read_timeout( query=None, consistency="ONE", required_responses=2, received_responses=1, data_retrieved=True, retry_num=0) self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) # if we got enough responses, but also got a data response, rethrow retry, consistency = policy.on_read_timeout( query=None, consistency="ONE", required_responses=2, received_responses=2, data_retrieved=True, retry_num=0) self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) - # we got enough reponses but no data response, so retry + # we got enough responses but no data response, so retry retry, consistency = policy.on_read_timeout( query=None, consistency="ONE", required_responses=2, received_responses=2, data_retrieved=False, retry_num=0) self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, "ONE") + self.assertEqual(consistency, 'ONE') def test_write_timeout(self): policy = RetryPolicy() @@ -298,19 +459,21 @@ class RetryPolicyTest(unittest.TestCase): query=None, consistency="ONE", write_type=WriteType.SIMPLE, required_responses=1, received_responses=2, retry_num=1) self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) # if it's not a BATCH_LOG write, don't retry it retry, consistency = policy.on_write_timeout( query=None, consistency="ONE", write_type=WriteType.SIMPLE, required_responses=1, received_responses=2, retry_num=0) self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) # retry BATCH_LOG writes regardless of received responses retry, consistency = policy.on_write_timeout( query=None, consistency="ONE", write_type=WriteType.BATCH_LOG, required_responses=10000, received_responses=1, retry_num=0) self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, "ONE") + self.assertEqual(consistency, 'ONE') class DowngradingConsistencyRetryPolicyTest(unittest.TestCase): @@ -323,6 +486,14 @@ class DowngradingConsistencyRetryPolicyTest(unittest.TestCase): query=None, consistency="ONE", required_responses=1, received_responses=2, data_retrieved=True, retry_num=1) self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + # if we didn't get enough responses, retry at a lower consistency + retry, consistency = policy.on_read_timeout( + query=None, consistency="ONE", required_responses=4, received_responses=3, + data_retrieved=True, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, ConsistencyLevel.THREE) # if we didn't get enough responses, retry at a lower consistency retry, consistency = policy.on_read_timeout( @@ -343,18 +514,21 @@ class DowngradingConsistencyRetryPolicyTest(unittest.TestCase): query=None, consistency="ONE", required_responses=3, received_responses=0, data_retrieved=True, retry_num=0) self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) # if we got enough response but no data, retry retry, consistency = policy.on_read_timeout( query=None, consistency="ONE", required_responses=3, received_responses=3, data_retrieved=False, retry_num=0) self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, 'ONE') # if we got enough responses, but also got a data response, rethrow retry, consistency = policy.on_read_timeout( query=None, consistency="ONE", required_responses=2, received_responses=2, data_retrieved=True, retry_num=0) self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) def test_write_timeout(self): policy = DowngradingConsistencyRetryPolicy() @@ -364,6 +538,7 @@ class DowngradingConsistencyRetryPolicyTest(unittest.TestCase): query=None, consistency="ONE", write_type=WriteType.SIMPLE, required_responses=1, received_responses=2, retry_num=1) self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) # ignore failures on these types of writes for write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER): @@ -386,6 +561,13 @@ class DowngradingConsistencyRetryPolicyTest(unittest.TestCase): self.assertEqual(retry, RetryPolicy.RETRY) self.assertEqual(consistency, "ONE") + # timeout on an unknown write_type + retry, consistency = policy.on_write_timeout( + query=None, consistency="ONE", write_type=None, + required_responses=1, received_responses=2, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + def test_unavailable(self): policy = DowngradingConsistencyRetryPolicy() @@ -393,6 +575,7 @@ class DowngradingConsistencyRetryPolicyTest(unittest.TestCase): retry, consistency = policy.on_unavailable( query=None, consistency="ONE", required_replicas=3, alive_replicas=1, retry_num=1) self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) # downgrade consistency on unavailable exceptions retry, consistency = policy.on_unavailable( diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py new file mode 100644 index 00000000..e215e5fa --- /dev/null +++ b/tests/unit/test_types.py @@ -0,0 +1,289 @@ +import unittest +import datetime +import cassandra +from cassandra.cqltypes import lookup_cqltype, CassandraType, BooleanType, lookup_casstype_simple, lookup_casstype, \ + AsciiType, LongType, DecimalType, DoubleType, FloatType, Int32Type, UTF8Type, IntegerType, SetType, cql_typename + +from cassandra.cluster import Cluster + + +class TypeTests(unittest.TestCase): + + def test_lookup_casstype_simple(self): + """ + Ensure lookup_casstype_simple returns the correct classes + """ + + self.assertEqual(lookup_casstype_simple('AsciiType'), cassandra.cqltypes.AsciiType) + self.assertEqual(lookup_casstype_simple('LongType'), cassandra.cqltypes.LongType) + self.assertEqual(lookup_casstype_simple('BytesType'), cassandra.cqltypes.BytesType) + self.assertEqual(lookup_casstype_simple('BooleanType'), cassandra.cqltypes.BooleanType) + self.assertEqual(lookup_casstype_simple('CounterColumnType'), cassandra.cqltypes.CounterColumnType) + self.assertEqual(lookup_casstype_simple('DecimalType'), cassandra.cqltypes.DecimalType) + self.assertEqual(lookup_casstype_simple('DoubleType'), cassandra.cqltypes.DoubleType) + self.assertEqual(lookup_casstype_simple('FloatType'), cassandra.cqltypes.FloatType) + self.assertEqual(lookup_casstype_simple('InetAddressType'), cassandra.cqltypes.InetAddressType) + self.assertEqual(lookup_casstype_simple('Int32Type'), cassandra.cqltypes.Int32Type) + self.assertEqual(lookup_casstype_simple('UTF8Type'), cassandra.cqltypes.UTF8Type) + self.assertEqual(lookup_casstype_simple('DateType'), cassandra.cqltypes.DateType) + self.assertEqual(lookup_casstype_simple('TimeUUIDType'), cassandra.cqltypes.TimeUUIDType) + self.assertEqual(lookup_casstype_simple('UUIDType'), cassandra.cqltypes.UUIDType) + + # DISCUSS: varchar is not valid? + # self.assertEqual(lookup_casstype_simple('varchar'), cassandra.cqltypes.UTF8Type) + + self.assertEqual(lookup_casstype_simple('IntegerType'), cassandra.cqltypes.IntegerType) + self.assertEqual(lookup_casstype_simple('MapType'), cassandra.cqltypes.MapType) + self.assertEqual(lookup_casstype_simple('ListType'), cassandra.cqltypes.ListType) + self.assertEqual(lookup_casstype_simple('SetType'), cassandra.cqltypes.SetType) + self.assertEqual(lookup_casstype_simple('CompositeType'), cassandra.cqltypes.CompositeType) + self.assertEqual(lookup_casstype_simple('ColumnToCollectionType'), cassandra.cqltypes.ColumnToCollectionType) + self.assertEqual(lookup_casstype_simple('ReversedType'), cassandra.cqltypes.ReversedType) + + self.assertEqual(str(lookup_casstype_simple('unknown')), str(cassandra.cqltypes.mkUnrecognizedType('unknown'))) + + def test_lookup_casstype(self): + """ + Ensure lookup_casstype returns the correct classes + """ + + self.assertEqual(lookup_casstype('AsciiType'), cassandra.cqltypes.AsciiType) + self.assertEqual(lookup_casstype('LongType'), cassandra.cqltypes.LongType) + self.assertEqual(lookup_casstype('BytesType'), cassandra.cqltypes.BytesType) + self.assertEqual(lookup_casstype('BooleanType'), cassandra.cqltypes.BooleanType) + self.assertEqual(lookup_casstype('CounterColumnType'), cassandra.cqltypes.CounterColumnType) + self.assertEqual(lookup_casstype('DecimalType'), cassandra.cqltypes.DecimalType) + self.assertEqual(lookup_casstype('DoubleType'), cassandra.cqltypes.DoubleType) + self.assertEqual(lookup_casstype('FloatType'), cassandra.cqltypes.FloatType) + self.assertEqual(lookup_casstype('InetAddressType'), cassandra.cqltypes.InetAddressType) + self.assertEqual(lookup_casstype('Int32Type'), cassandra.cqltypes.Int32Type) + self.assertEqual(lookup_casstype('UTF8Type'), cassandra.cqltypes.UTF8Type) + self.assertEqual(lookup_casstype('DateType'), cassandra.cqltypes.DateType) + self.assertEqual(lookup_casstype('TimeUUIDType'), cassandra.cqltypes.TimeUUIDType) + self.assertEqual(lookup_casstype('UUIDType'), cassandra.cqltypes.UUIDType) + + # DISCUSS: varchar is not valid? + # self.assertEqual(lookup_casstype('varchar'), cassandra.cqltypes.UTF8Type) + + self.assertEqual(lookup_casstype('IntegerType'), cassandra.cqltypes.IntegerType) + self.assertEqual(lookup_casstype('MapType'), cassandra.cqltypes.MapType) + self.assertEqual(lookup_casstype('ListType'), cassandra.cqltypes.ListType) + self.assertEqual(lookup_casstype('SetType'), cassandra.cqltypes.SetType) + self.assertEqual(lookup_casstype('CompositeType'), cassandra.cqltypes.CompositeType) + self.assertEqual(lookup_casstype('ColumnToCollectionType'), cassandra.cqltypes.ColumnToCollectionType) + self.assertEqual(lookup_casstype('ReversedType'), cassandra.cqltypes.ReversedType) + + self.assertEqual(str(lookup_casstype('unknown')), str(cassandra.cqltypes.mkUnrecognizedType('unknown'))) + + self.assertRaises(ValueError, lookup_casstype, 'AsciiType~') + + # DISCUSS: Figure out if other tests are needed + self.assertEqual(str(lookup_casstype(BooleanType(True))), str(BooleanType(True))) + + + def test_lookup_cqltype(self): + """ + Ensure lookup_cqltype returns the correct class + """ + + self.assertEqual(lookup_cqltype('ascii'), cassandra.cqltypes.AsciiType) + self.assertEqual(lookup_cqltype('bigint'), cassandra.cqltypes.LongType) + self.assertEqual(lookup_cqltype('blob'), cassandra.cqltypes.BytesType) + self.assertEqual(lookup_cqltype('boolean'), cassandra.cqltypes.BooleanType) + self.assertEqual(lookup_cqltype('counter'), cassandra.cqltypes.CounterColumnType) + self.assertEqual(lookup_cqltype('decimal'), cassandra.cqltypes.DecimalType) + self.assertEqual(lookup_cqltype('double'), cassandra.cqltypes.DoubleType) + self.assertEqual(lookup_cqltype('float'), cassandra.cqltypes.FloatType) + self.assertEqual(lookup_cqltype('inet'), cassandra.cqltypes.InetAddressType) + self.assertEqual(lookup_cqltype('int'), cassandra.cqltypes.Int32Type) + self.assertEqual(lookup_cqltype('text'), cassandra.cqltypes.UTF8Type) + self.assertEqual(lookup_cqltype('timestamp'), cassandra.cqltypes.DateType) + self.assertEqual(lookup_cqltype('timeuuid'), cassandra.cqltypes.TimeUUIDType) + self.assertEqual(lookup_cqltype('uuid'), cassandra.cqltypes.UUIDType) + self.assertEqual(lookup_cqltype('varchar'), cassandra.cqltypes.UTF8Type) + self.assertEqual(lookup_cqltype('varint'), cassandra.cqltypes.IntegerType) + + + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.AsciiType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.LongType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.BytesType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.BooleanType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.CounterColumnType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.DecimalType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.DoubleType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.FloatType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.InetAddressType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.Int32Type))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.UTF8Type))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.DateType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.TimeUUIDType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.UUIDType))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.UTF8Type))) + self.assertEqual(str(lookup_cqltype('list')), + str(cassandra.cqltypes.ListType.apply_parameters(cassandra.cqltypes.IntegerType))) + + + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.AsciiType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.LongType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.BytesType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.BooleanType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.CounterColumnType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.DecimalType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.DoubleType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.FloatType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.InetAddressType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.Int32Type))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.UTF8Type))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.DateType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.TimeUUIDType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.UUIDType))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.UTF8Type))) + self.assertEqual(str(lookup_cqltype('set')), + str(cassandra.cqltypes.SetType.apply_parameters(cassandra.cqltypes.IntegerType))) + + + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.AsciiType, + cassandra.cqltypes.AsciiType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.LongType, + cassandra.cqltypes.LongType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.BytesType, + cassandra.cqltypes.BytesType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.BooleanType, + cassandra.cqltypes.BooleanType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.CounterColumnType, + cassandra.cqltypes.CounterColumnType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.DecimalType, + cassandra.cqltypes.DecimalType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.DoubleType, + cassandra.cqltypes.DoubleType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.FloatType, + cassandra.cqltypes.FloatType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.InetAddressType, + cassandra.cqltypes.InetAddressType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.Int32Type, + cassandra.cqltypes.Int32Type))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.UTF8Type, + cassandra.cqltypes.UTF8Type))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.DateType, + cassandra.cqltypes.DateType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.TimeUUIDType, + cassandra.cqltypes.TimeUUIDType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.UUIDType, + cassandra.cqltypes.UUIDType))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.UTF8Type, + cassandra.cqltypes.UTF8Type))) + self.assertEqual(str(lookup_cqltype('map')), + str(cassandra.cqltypes.MapType.apply_parameters(cassandra.cqltypes.IntegerType, + cassandra.cqltypes.IntegerType))) + + # DISCUSS: Figure out if other tests are needed, and how to test them + # self.assertEqual(str(lookup_cqltype(AsciiType(CassandraType('asdf')))), str(AsciiType(CassandraType('asdf')))) + # self.assertEqual(str(lookup_cqltype(LongType(CassandraType(1234)))), str(LongType(CassandraType(1234)))) + # self.assertEqual(str(lookup_cqltype(BytesType(CassandraType(True)))), str(BytesType(CassandraType(True)))) + self.assertEqual(str(lookup_cqltype(BooleanType(CassandraType(True)))), str(BooleanType(CassandraType(True)))) + # self.assertEqual(str(lookup_cqltype(CounterColumnType(CassandraType(True)))), str(CounterColumnType(CassandraType(True)))) + # self.assertEqual(str(lookup_cqltype(DecimalType(CassandraType(1234.1234)))), str(DecimalType(CassandraType(1234.1234)))) + # self.assertEqual(str(lookup_cqltype(DoubleType(CassandraType(1234.1234)))), str(DoubleType(CassandraType(1234.1234)))) + # self.assertEqual(str(lookup_cqltype(FloatType(CassandraType(1234.1234)))), str(FloatType(CassandraType(1234.1234)))) + # self.assertEqual(str(lookup_cqltype(InetAddressType(CassandraType(True)))), str(InetAddressType(CassandraType(True)))) + # self.assertEqual(str(lookup_cqltype(Int32Type(CassandraType(1234)))), str(Int32Type(CassandraType(1234)))) + # self.assertEqual(str(lookup_cqltype(UTF8Type(CassandraType('asdf')))), str(UTF8Type(CassandraType('asdf')))) + # self.assertEqual(str(lookup_cqltype(DateType(CassandraType(True)))), str(DateType(CassandraType(True)))) + # self.assertEqual(str(lookup_cqltype(TimeUUIDType(CassandraType(True)))), str(TimeUUIDType(CassandraType(True)))) + # self.assertEqual(str(lookup_cqltype(UUIDType(CassandraType(True)))), str(UUIDType(CassandraType(True)))) + # self.assertEqual(str(lookup_cqltype(IntegerType(CassandraType(1234)))), str(IntegerType(CassandraType(1234)))) + + # DISCUSS: Check if typo in code, or misunderstanding + # self.assertEqual(lookup_cqltype("'ascii'"), cassandra.cqltypes.AsciiType) + # self.assertEqual(lookup_cqltype("'bigint'"), cassandra.cqltypes.LongType) + # self.assertEqual(lookup_cqltype("'blob'"), cassandra.cqltypes.BytesType) + # self.assertEqual(lookup_cqltype("'boolean'"), cassandra.cqltypes.BooleanType) + # self.assertEqual(lookup_cqltype("'counter'"), cassandra.cqltypes.CounterColumnType) + # self.assertEqual(lookup_cqltype("'decimal'"), cassandra.cqltypes.DecimalType) + # self.assertEqual(lookup_cqltype("'float'"), cassandra.cqltypes.FloatType) + # self.assertEqual(lookup_cqltype("'inet'"), cassandra.cqltypes.InetAddressType) + # self.assertEqual(lookup_cqltype("'int'"), cassandra.cqltypes.Int32Type) + # self.assertEqual(lookup_cqltype("'text'"), cassandra.cqltypes.UTF8Type) + # self.assertEqual(lookup_cqltype("'timestamp'"), cassandra.cqltypes.DateType) + # self.assertEqual(lookup_cqltype("'timeuuid'"), cassandra.cqltypes.TimeUUIDType) + # self.assertEqual(lookup_cqltype("'uuid'"), cassandra.cqltypes.UUIDType) + # self.assertEqual(lookup_cqltype("'varchar'"), cassandra.cqltypes.UTF8Type) + # self.assertEqual(lookup_cqltype("'varint'"), cassandra.cqltypes.IntegerType) + + def test_cassandratype(self): + """ + Smoke test cass_parameterized_type_with + """ + + self.assertEqual(LongType.cass_parameterized_type_with(()), 'LongType') + self.assertEqual(LongType.cass_parameterized_type_with((), full=True), 'org.apache.cassandra.db.marshal.LongType') + self.assertEqual(SetType.cass_parameterized_type_with([DecimalType], full=True), 'org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)') + + self.assertEqual(LongType.cql_parameterized_type(), 'bigint') + self.assertEqual(cassandra.cqltypes.MapType.apply_parameters( + cassandra.cqltypes.UTF8Type, cassandra.cqltypes.UTF8Type).cql_parameterized_type(), + 'map') + + def test_datetype(self): + """ + Test cassandra.cqltypes.DateType() construction + """ + + pass + # TODO: Figure out the required format here + # date_string = str(datetime.datetime.now().strftime('%s.%f')) + # print date_string + # print cassandra.cqltypes.DateType(date_string) + + def test_cql_typename(self): + """ + Smoke test cql_typename + """ + + self.assertEqual(cql_typename('DateType'), 'timestamp') + self.assertEqual(cql_typename('org.apache.cassandra.db.marshal.ListType(IntegerType)'), 'list')