diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a3618ca1..0e4fb15e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,31 @@ +3.3.0 +===== +May 2, 2016 + +Features +-------- +* Add an AddressTranslator interface (PYTHON-69) +* New Retry Policy Decision - try next host (PYTHON-285) +* Don't mark host down on timeout (PYTHON-286) +* SSL hostname verification (PYTHON-296) +* Add C* version to metadata or cluster objects (PYTHON-301) +* Expose listen_address of node we get ring information from (PYTHON-332) +* Use A-record with multiple IPs for contact points (PYTHON-415) +* Custom consistency level for populating query traces (PYTHON-435) +* Normalize Server Exception Types (PYTHON-443) +* Propagate exception message when DDL schema agreement fails (PYTHON-444) +* Specialized exceptions for metadata refresh methods failure (PYTHON-527) + +Bug Fixes +--------- +* Resolve contact point hostnames to avoid duplicate hosts (PYTHON-103) +* Options to Disable Schema, Token Metadata Processing (PYTHON-327) +* GeventConnection stalls requests when read is a multiple of the input buffer size (PYTHON-429) +* named_tuple_factory breaks with duplicate "cleaned" col names (PYTHON-467) +* Connection leak if Cluster.shutdown() happens during reconnection (PYTHON-482) +* HostConnection.borrow_connection does not block when all request ids are used (PYTHON-514) +* Empty field not being handled by the NumpyProtocolHandler (PYTHON-550) + 3.2.2 ===== April 19, 2016 diff --git a/benchmarks/base.py b/benchmarks/base.py index 812db42a..e9184697 100644 --- a/benchmarks/base.py +++ b/benchmarks/base.py @@ -19,6 +19,7 @@ import sys from threading import Thread import time from optparse import OptionParser +import uuid from greplin import scales @@ -59,49 +60,65 @@ except ImportError as exc: KEYSPACE = "testkeyspace" + str(int(time.time())) TABLE = "testtable" +COLUMN_VALUES = { + 'int': 42, + 'text': "'42'", + 'float': 42.0, + 'uuid': uuid.uuid4(), + 'timestamp': "'2016-02-03 04:05+0000'" +} -def setup(hosts): + +def setup(options): log.info("Using 'cassandra' package from %s", cassandra.__path__) - cluster = Cluster(hosts, protocol_version=1) - cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) + cluster = Cluster(options.hosts, protocol_version=options.protocol_version) try: session = cluster.connect() log.debug("Creating keyspace...") - session.execute(""" - CREATE KEYSPACE %s - WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } - """ % KEYSPACE) + try: + session.execute(""" + CREATE KEYSPACE %s + WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } + """ % options.keyspace) - log.debug("Setting keyspace...") - session.set_keyspace(KEYSPACE) + log.debug("Setting keyspace...") + except cassandra.AlreadyExists: + log.debug("Keyspace already exists") + + session.set_keyspace(options.keyspace) log.debug("Creating table...") - session.execute(""" - CREATE TABLE %s ( + create_table_query = """ + CREATE TABLE {} ( thekey text, - col1 text, - col2 text, - PRIMARY KEY (thekey, col1) - ) - """ % TABLE) + """ + for i in range(options.num_columns): + create_table_query += "col{} {},\n".format(i, options.column_type) + create_table_query += "PRIMARY KEY (thekey))" + + try: + session.execute(create_table_query.format(TABLE)) + except cassandra.AlreadyExists: + log.debug("Table already exists.") + finally: cluster.shutdown() -def teardown(hosts): - cluster = Cluster(hosts, protocol_version=1) - cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) +def 
teardown(options): + cluster = Cluster(options.hosts, protocol_version=options.protocol_version) session = cluster.connect() - session.execute("DROP KEYSPACE " + KEYSPACE) + if not options.keep_data: + session.execute("DROP KEYSPACE " + options.keyspace) cluster.shutdown() def benchmark(thread_class): options, args = parse_options() for conn_class in options.supported_reactors: - setup(options.hosts) + setup(options) log.info("==== %s ====" % (conn_class.__name__,)) kwargs = {'metrics_enabled': options.enable_metrics, @@ -109,20 +126,30 @@ def benchmark(thread_class): if options.protocol_version: kwargs['protocol_version'] = options.protocol_version cluster = Cluster(options.hosts, **kwargs) - session = cluster.connect(KEYSPACE) + session = cluster.connect(options.keyspace) log.debug("Sleeping for two seconds...") time.sleep(2.0) - query = session.prepare(""" - INSERT INTO {table} (thekey, col1, col2) VALUES (?, ?, ?) - """.format(table=TABLE)) - values = ('key', 'a', 'b') + # Generate the query + if options.read: + query = "SELECT * FROM {} WHERE thekey = '{{key}}'".format(TABLE) + else: + query = "INSERT INTO {} (thekey".format(TABLE) + for i in range(options.num_columns): + query += ", col{}".format(i) + + query += ") VALUES ('{key}'" + for i in range(options.num_columns): + query += ", {}".format(COLUMN_VALUES[options.column_type]) + query += ")" + + values = None # we don't use that anymore. Keeping it in case we go back to prepared statements. per_thread = options.num_ops // options.threads threads = [] - log.debug("Beginning inserts...") + log.debug("Beginning {}...".format('reads' if options.read else 'inserts')) start = time.time() try: for i in range(options.threads): @@ -142,7 +169,7 @@ def benchmark(thread_class): end = time.time() finally: cluster.shutdown() - teardown(options.hosts) + teardown(options) total = end - start log.info("Total time: %0.2fs" % total) @@ -190,8 +217,19 @@ def parse_options(): help='logging level: debug, info, warning, or error') parser.add_option('-p', '--profile', action='store_true', dest='profile', help='Profile the run') - parser.add_option('--protocol-version', type='int', dest='protocol_version', + parser.add_option('--protocol-version', type='int', dest='protocol_version', default=4, help='Native protocol version to use') + parser.add_option('-c', '--num-columns', type='int', dest='num_columns', default=2, + help='Specify the number of columns for the schema') + parser.add_option('-k', '--keyspace', type='str', dest='keyspace', default=KEYSPACE, + help='Specify the keyspace name for the schema') + parser.add_option('--keep-data', action='store_true', dest='keep_data', default=False, + help='Keep the data after the benchmark') + parser.add_option('--column-type', type='str', dest='column_type', default='text', + help='Specify the column type for the schema (supported: int, text, float, uuid, timestamp)') + parser.add_option('--read', action='store_true', dest='read', default=False, + help='Read mode') + options, args = parser.parse_args() @@ -235,6 +273,9 @@ class BenchmarkThread(Thread): if self.profiler: self.profiler.enable() + def run_query(self, key, **kwargs): + return self.session.execute_async(self.query.format(key=key), **kwargs) + def finish_profile(self): if self.profiler: self.profiler.disable() diff --git a/benchmarks/callback_full_pipeline.py b/benchmarks/callback_full_pipeline.py index 3736991b..a7990d80 100644 --- a/benchmarks/callback_full_pipeline.py +++ b/benchmarks/callback_full_pipeline.py @@ -41,8 +41,10 @@ class 
Runner(BenchmarkThread): if next(self.num_finished) >= self.num_queries: self.event.set() - if next(self.num_started) <= self.num_queries: - future = self.session.execute_async(self.query, self.values, timeout=None) + i = next(self.num_started) + if i <= self.num_queries: + key = "{}-{}".format(self.thread_num, i) + future = self.run_query(key, timeout=None) future.add_callbacks(self.insert_next, self.insert_next) def run(self): diff --git a/benchmarks/future_batches.py b/benchmarks/future_batches.py index 91c250bc..c8305369 100644 --- a/benchmarks/future_batches.py +++ b/benchmarks/future_batches.py @@ -35,7 +35,8 @@ class Runner(BenchmarkThread): except queue.Empty: break - future = self.session.execute_async(self.query, self.values) + key = "{}-{}".format(self.thread_num, i) + future = self.run_query(key) futures.put_nowait(future) while True: diff --git a/benchmarks/future_full_pipeline.py b/benchmarks/future_full_pipeline.py index 40682e04..ecc2ce6f 100644 --- a/benchmarks/future_full_pipeline.py +++ b/benchmarks/future_full_pipeline.py @@ -31,7 +31,8 @@ class Runner(BenchmarkThread): old_future = futures.get_nowait() old_future.result() - future = self.session.execute_async(self.query, self.values) + key = "{}-{}".format(self.thread_num, i) + future = self.run_query(key) futures.put_nowait(future) while True: diff --git a/benchmarks/future_full_throttle.py b/benchmarks/future_full_throttle.py index 27d87442..2e47b19f 100644 --- a/benchmarks/future_full_throttle.py +++ b/benchmarks/future_full_throttle.py @@ -25,8 +25,9 @@ class Runner(BenchmarkThread): self.start_profile() - for _ in range(self.num_queries): - future = self.session.execute_async(self.query, self.values) + for i in range(self.num_queries): + key = "{}-{}".format(self.thread_num, i) + future = self.run_query(key) futures.append(future) for future in futures: diff --git a/build.yaml b/build.yaml index ee7dd4d9..e9f4b0c8 100644 --- a/build.yaml +++ b/build.yaml @@ -27,8 +27,9 @@ build: fi pip install -r test-requirements.txt pip install nose-ignore-docstring - + FORCE_CYTHON=False if [[ $CYTHON == 'CYTHON' ]]; then + FORCE_CYTHON=True pip install cython pip install numpy # Install the driver & compile C extensions @@ -39,12 +40,12 @@ build: fi echo "==========RUNNING CQLENGINE TESTS==========" - CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=cqle_results.xml tests/integration/cqlengine/ || true + CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=cqle_results.xml tests/integration/cqlengine/ || true echo "==========RUNNING INTEGRATION TESTS==========" - CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/ || true + CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/ || true echo "==========RUNNING LONG INTEGRATION TESTS==========" - CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" 
--with-ignore-docstrings --with-xunit --xunit-file=long_results.xml tests/integration/long/ || true + CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=long_results.xml tests/integration/long/ || true - xunit: - "*_results.xml" diff --git a/cassandra/__init__.py b/cassandra/__init__.py index c93058b0..b6a16347 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -22,7 +22,7 @@ class NullHandler(logging.Handler): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 2, 2) +__version_info__ = (3, 3, 0) __version__ = '.'.join(map(str, __version_info__)) @@ -194,7 +194,21 @@ class UserAggregateDescriptor(SignatureDescriptor): """ -class Unavailable(Exception): +class DriverException(Exception): + """ + Base for all exceptions explicitly raised by the driver. + """ + pass + + +class RequestExecutionException(DriverException): + """ + Base for request execution exceptions returned from the server. + """ + pass + + +class Unavailable(RequestExecutionException): """ There were not enough live replicas to satisfy the requested consistency level, so the coordinator node immediately failed the request without @@ -220,7 +234,7 @@ class Unavailable(Exception): 'alive_replicas': alive_replicas})) -class Timeout(Exception): +class Timeout(RequestExecutionException): """ Replicas failed to respond to the coordinator node before timing out. """ @@ -289,7 +303,7 @@ class WriteTimeout(Timeout): self.write_type = write_type -class CoordinationFailure(Exception): +class CoordinationFailure(RequestExecutionException): """ Replicas sent a failure to the coordinator. """ @@ -359,7 +373,7 @@ class WriteFailure(CoordinationFailure): self.write_type = write_type -class FunctionFailure(Exception): +class FunctionFailure(RequestExecutionException): """ User Defined Function failed during execution """ @@ -386,7 +400,21 @@ class FunctionFailure(Exception): Exception.__init__(self, summary_message) -class AlreadyExists(Exception): +class RequestValidationException(DriverException): + """ + Server request validation failed + """ + pass + + +class ConfigurationException(RequestValidationException): + """ + Server indicated request error due to current configuration + """ + pass + + +class AlreadyExists(ConfigurationException): """ An attempt was made to create a keyspace or table that already exists. """ @@ -414,7 +442,7 @@ class AlreadyExists(Exception): self.table = table -class InvalidRequest(Exception): +class InvalidRequest(RequestValidationException): """ A query was made that was invalid for some reason, such as trying to set the keyspace for a connection to a nonexistent keyspace. @@ -422,21 +450,21 @@ class InvalidRequest(Exception): pass -class Unauthorized(Exception): +class Unauthorized(RequestValidationException): """ - The current user is not authorized to perfom the requested operation. + The current user is not authorized to perform the requested operation. """ pass -class AuthenticationFailed(Exception): +class AuthenticationFailed(DriverException): """ Failed to authenticate. """ pass -class OperationTimedOut(Exception): +class OperationTimedOut(DriverException): """ The operation took longer than the specified (client-side) timeout to complete.
This is not an error generated by Cassandra, only @@ -460,7 +488,7 @@ class OperationTimedOut(Exception): Exception.__init__(self, message) -class UnsupportedOperation(Exception): +class UnsupportedOperation(DriverException): """ An attempt was made to use a feature that is not supported by the selected protocol version. See :attr:`Cluster.protocol_version` diff --git a/cassandra/cluster.py b/cassandra/cluster.py index d8c1026e..58830007 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -45,7 +45,7 @@ from itertools import groupby, count from cassandra import (ConsistencyLevel, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, - SchemaTargetType) + SchemaTargetType, DriverException) from cassandra.connection import (ConnectionException, ConnectionShutdown, ConnectionHeartbeat, ProtocolVersionUnsupported) from cassandra.cqltypes import UserType @@ -65,7 +65,7 @@ from cassandra.protocol import (QueryMessage, ResultMessage, from cassandra.metadata import Metadata, protect_name, murmur3 from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, ExponentialReconnectionPolicy, HostDistance, - RetryPolicy) + RetryPolicy, IdentityTranslator) from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler, HostConnectionPool, HostConnection, NoConnectionsAvailable) @@ -347,6 +347,12 @@ class Cluster(object): :class:`.policies.SimpleConvictionPolicy`. """ + address_translator = IdentityTranslator() + """ + :class:`.policies.AddressTranslator` instance to be used in translating server node addresses + to driver connection addresses. + """ + connect_to_remote_hosts = True """ If left as :const:`True`, hosts that are considered :attr:`~.HostDistance.REMOTE` @@ -380,6 +386,14 @@ class Cluster(object): a string pointing to the location of the CA certs file), and you probably want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLSv1`` to match Cassandra's default protocol. + + .. versionchanged:: 3.3.0 + + In addition to ``wrap_socket`` kwargs, clients may also specify ``'check_hostname': True`` to verify the cert hostname + as outlined in RFC 2818 and RFC 6125. Note that this requires the certificate to be transferred, so + should almost always require the option ``'cert_reqs': ssl.CERT_REQUIRED``. Note also that this functionality was not built into + Python standard library until (2.7.9, 3.2). To enable this mechanism in earlier versions, patch ``ssl.match_hostname`` + with a custom or `back-ported function `_. """ sockopts = None @@ -481,6 +495,37 @@ class Cluster(object): establishment, options passing, and authentication. """ + @property + def schema_metadata_enabled(self): + """ + Flag indicating whether internal schema metadata is updated. + + When disabled, the driver does not populate Cluster.metadata.keyspaces on connect, or on schema change events. This + can be used to speed initial connection, and reduce load on client and server during operation. Turning this off + gives away token aware request routing, and programmatic inspection of the metadata model. + """ + return self.control_connection._schema_meta_enabled + + @schema_metadata_enabled.setter + def schema_metadata_enabled(self, enabled): + self.control_connection._schema_meta_enabled = bool(enabled) + + @property + def token_metadata_enabled(self): + """ + Flag indicating whether internal token metadata is updated. + + When disabled, the driver does not query node token information on connect, or on topology change events. 
This + can be used to speed initial connection, and reduce load on client and server during operation. It is most useful + in large clusters using vnodes, where the token map can be expensive to compute. Turning this off + gives away token aware request routing, and programmatic inspection of the token ring. + """ + return self.control_connection._token_meta_enabled + + @token_metadata_enabled.setter + def token_metadata_enabled(self, enabled): + self.control_connection._token_meta_enabled = bool(enabled) + sessions = None control_connection = None scheduler = None @@ -520,7 +565,10 @@ class Cluster(object): idle_heartbeat_interval=30, schema_event_refresh_window=2, topology_event_refresh_window=10, - connect_timeout=5): + connect_timeout=5, + schema_metadata_enabled=True, + token_metadata_enabled=True, + address_translator=None): """ Any of the mutable Cluster attributes may be set as keyword arguments to the constructor. @@ -529,9 +577,15 @@ class Cluster(object): if isinstance(contact_points, six.string_types): raise TypeError("contact_points should not be a string, it should be a sequence (e.g. list) of strings") + if None in contact_points: + raise ValueError("contact_points should not contain None (it can resolve to localhost)") self.contact_points = contact_points self.port = port + + self.contact_points_resolved = [endpoint[4][0] for a in self.contact_points + for endpoint in socket.getaddrinfo(a, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM)] + self.compression = compression self.protocol_version = protocol_version self.auth_provider = auth_provider @@ -539,7 +593,6 @@ class Cluster(object): if load_balancing_policy is not None: if isinstance(load_balancing_policy, type): raise TypeError("load_balancing_policy should not be a class, it should be an instance of that class") - self.load_balancing_policy = load_balancing_policy else: self.load_balancing_policy = default_lbp_factory() @@ -547,13 +600,11 @@ class Cluster(object): if reconnection_policy is not None: if isinstance(reconnection_policy, type): raise TypeError("reconnection_policy should not be a class, it should be an instance of that class") - self.reconnection_policy = reconnection_policy if default_retry_policy is not None: if isinstance(default_retry_policy, type): raise TypeError("default_retry_policy should not be a class, it should be an instance of that class") - self.default_retry_policy = default_retry_policy if conviction_policy_factory is not None: @@ -561,6 +612,11 @@ class Cluster(object): raise ValueError("conviction_policy_factory must be callable") self.conviction_policy_factory = conviction_policy_factory + if address_translator is not None: + if isinstance(address_translator, type): + raise TypeError("address_translator should not be a class, it should be an instance of that class") + self.address_translator = address_translator + if connection_class is not None: self.connection_class = connection_class @@ -619,7 +675,9 @@ class Cluster(object): self.control_connection = ControlConnection( self, self.control_connection_timeout, - self.schema_event_refresh_window, self.topology_event_refresh_window) + self.schema_event_refresh_window, self.topology_event_refresh_window, + schema_metadata_enabled, token_metadata_enabled) + def register_user_type(self, keyspace, user_type, klass): """ @@ -813,7 +871,7 @@ class Cluster(object): log.warning("Downgrading core protocol version from %d to %d for %s", self.protocol_version, new_version, host_addr) self.protocol_version = new_version else: - raise 
Exception("Cannot downgrade protocol version (%d) below minimum supported version: %d" % (new_version, MIN_SUPPORTED_VERSION)) + raise DriverException("Cannot downgrade protocol version (%d) below minimum supported version: %d" % (new_version, MIN_SUPPORTED_VERSION)) def connect(self, keyspace=None): """ @@ -823,14 +881,14 @@ class Cluster(object): """ with self._lock: if self.is_shutdown: - raise Exception("Cluster is already shut down") + raise DriverException("Cluster is already shut down") if not self._is_setup: log.debug("Connecting to cluster, contact points: %s; protocol version: %s", self.contact_points, self.protocol_version) self.connection_class.initialize_reactor() atexit.register(partial(_shutdown_cluster, self)) - for address in self.contact_points: + for address in self.contact_points_resolved: host, new = self.add_host(address, signal=False) if new: host.set_up() @@ -1244,8 +1302,8 @@ class Cluster(object): An Exception is raised if schema refresh fails for any reason. """ - if not self.control_connection.refresh_schema(schema_agreement_wait=max_schema_agreement_wait): - raise Exception("Schema metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("Schema metadata was not refreshed. See log for details.") def refresh_keyspace_metadata(self, keyspace, max_schema_agreement_wait=None): """ @@ -1255,8 +1313,8 @@ class Cluster(object): See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.KEYSPACE, keyspace=keyspace, - schema_agreement_wait=max_schema_agreement_wait): - raise Exception("Keyspace metadata was not refreshed. See log for details.") + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("Keyspace metadata was not refreshed. See log for details.") def refresh_table_metadata(self, keyspace, table, max_schema_agreement_wait=None): """ @@ -1265,8 +1323,9 @@ class Cluster(object): See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=table, schema_agreement_wait=max_schema_agreement_wait): - raise Exception("Table metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=table, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("Table metadata was not refreshed. See log for details.") def refresh_materialized_view_metadata(self, keyspace, view, max_schema_agreement_wait=None): """ @@ -1274,8 +1333,9 @@ class Cluster(object): See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=view, schema_agreement_wait=max_schema_agreement_wait): - raise Exception("View metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=view, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("View metadata was not refreshed. 
See log for details.") def refresh_user_type_metadata(self, keyspace, user_type, max_schema_agreement_wait=None): """ @@ -1283,8 +1343,9 @@ class Cluster(object): See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TYPE, keyspace=keyspace, type=user_type, schema_agreement_wait=max_schema_agreement_wait): - raise Exception("User Type metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TYPE, keyspace=keyspace, type=user_type, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("User Type metadata was not refreshed. See log for details.") def refresh_user_function_metadata(self, keyspace, function, max_schema_agreement_wait=None): """ @@ -1294,8 +1355,9 @@ class Cluster(object): See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.FUNCTION, keyspace=keyspace, function=function, schema_agreement_wait=max_schema_agreement_wait): - raise Exception("User Function metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.FUNCTION, keyspace=keyspace, function=function, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("User Function metadata was not refreshed. See log for details.") def refresh_user_aggregate_metadata(self, keyspace, aggregate, max_schema_agreement_wait=None): """ @@ -1305,8 +1367,9 @@ class Cluster(object): See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.AGGREGATE, keyspace=keyspace, aggregate=aggregate, schema_agreement_wait=max_schema_agreement_wait): - raise Exception("User Aggregate metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.AGGREGATE, keyspace=keyspace, aggregate=aggregate, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("User Aggregate metadata was not refreshed. See log for details.") def refresh_nodes(self): """ @@ -1315,10 +1378,12 @@ class Cluster(object): An Exception is raised if node refresh fails for any reason. """ if not self.control_connection.refresh_node_list_and_token_map(): - raise Exception("Node list was not refreshed. See log for details.") + raise DriverException("Node list was not refreshed. See log for details.") def set_meta_refresh_enabled(self, enabled): """ + *Deprecated:* set :attr:`~.Cluster.schema_metadata_enabled` :attr:`~.Cluster.token_metadata_enabled` instead + Sets a flag to enable (True) or disable (False) all metadata refresh queries. This applies to both schema and node topology. @@ -1327,7 +1392,8 @@ class Cluster(object): Meta refresh must be enabled for the driver to become aware of any cluster topology changes or schema updates. 
""" - self.control_connection.set_meta_refresh_enabled(bool(enabled)) + self.schema_metadata_enabled = enabled + self.token_metadata_enabled = enabled def _prepare_all_queries(self, host): if not self._prepared_statements: @@ -2009,8 +2075,10 @@ class ControlConnection(object): Internal """ - _SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address, schema_version FROM system.peers" - _SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner, release_version, schema_version FROM system.local WHERE key='local'" + _SELECT_PEERS = "SELECT * FROM system.peers" + _SELECT_PEERS_NO_TOKENS = "SELECT peer, data_center, rack, rpc_address, release_version, schema_version FROM system.peers" + _SELECT_LOCAL = "SELECT * FROM system.local WHERE key='local'" + _SELECT_LOCAL_NO_TOKENS = "SELECT cluster_name, data_center, rack, partitioner, release_version, schema_version FROM system.local WHERE key='local'" _SELECT_SCHEMA_PEERS = "SELECT peer, rpc_address, schema_version FROM system.peers" _SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'" @@ -2022,14 +2090,17 @@ class ControlConnection(object): _schema_event_refresh_window = None _topology_event_refresh_window = None - _meta_refresh_enabled = True + _schema_meta_enabled = True + _token_meta_enabled = True # for testing purposes _time = time def __init__(self, cluster, timeout, schema_event_refresh_window, - topology_event_refresh_window): + topology_event_refresh_window, + schema_meta_enabled=True, + token_meta_enabled=True): # use a weak reference to allow the Cluster instance to be GC'ed (and # shutdown) since implementing __del__ disables the cycle detector self._cluster = weakref.proxy(cluster) @@ -2038,6 +2109,8 @@ class ControlConnection(object): self._schema_event_refresh_window = schema_event_refresh_window self._topology_event_refresh_window = topology_event_refresh_window + self._schema_meta_enabled = schema_meta_enabled + self._token_meta_enabled = token_meta_enabled self._lock = RLock() self._schema_agreement_lock = Lock() @@ -2086,6 +2159,8 @@ class ControlConnection(object): except Exception as exc: errors[host.address] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) + if self._is_shutdown: + raise DriverException("[control connection] Reconnection in progress during shutdown") raise NoHostAvailable("Unable to connect to any servers", errors) @@ -2099,6 +2174,9 @@ class ControlConnection(object): while True: try: connection = self._cluster.connection_factory(host.address, is_control_connection=True) + if self._is_shutdown: + connection.close() + raise DriverException("Reconnecting during shutdown") break except ProtocolVersionUnsupported as e: self._cluster.protocol_downgrade(host.address, e.startup_version) @@ -2119,8 +2197,10 @@ class ControlConnection(object): "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') }, register_timeout=self._timeout) - peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=ConsistencyLevel.ONE) - local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=ConsistencyLevel.ONE) + sel_peers = self._SELECT_PEERS if self._token_meta_enabled else self._SELECT_PEERS_NO_TOKENS + sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS + peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE) + local_query = QueryMessage(query=sel_local, consistency_level=ConsistencyLevel.ONE) shared_results = 
connection.wait_for_responses( peers_query, local_query, timeout=self._timeout) @@ -2198,16 +2278,12 @@ class ControlConnection(object): log.debug("Shutting down control connection") if self._connection: self._connection.close() - del self._connection - - def refresh_schema(self, **kwargs): - if not self._meta_refresh_enabled: - log.debug("[control connection] Skipping schema refresh because meta refresh is disabled") - return False + self._connection = None + def refresh_schema(self, force=False, **kwargs): try: if self._connection: - return self._refresh_schema(self._connection, **kwargs) + return self._refresh_schema(self._connection, force=force, **kwargs) except ReferenceError: pass # our weak reference to the Cluster is no good except Exception: @@ -2215,13 +2291,18 @@ class ControlConnection(object): self._signal_error() return False - def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_wait=None, **kwargs): + def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_wait=None, force=False, **kwargs): if self._cluster.is_shutdown: return False agreed = self.wait_for_schema_agreement(connection, preloaded_results=preloaded_results, wait_time=schema_agreement_wait) + + if not self._schema_meta_enabled and not force: + log.debug("[control connection] Skipping schema refresh because schema metadata is disabled") + return False + if not agreed: log.debug("Skipping schema refresh due to lack of schema agreement") return False @@ -2231,10 +2312,6 @@ class ControlConnection(object): return True def refresh_node_list_and_token_map(self, force_token_rebuild=False): - if not self._meta_refresh_enabled: - log.debug("[control connection] Skipping node list refresh because meta refresh is disabled") - return False - try: if self._connection: self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild) @@ -2254,10 +2331,17 @@ class ControlConnection(object): peers_result = preloaded_results[0] local_result = preloaded_results[1] else: - log.debug("[control connection] Refreshing node list and token map") cl = ConsistencyLevel.ONE - peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl) - local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl) + if not self._token_meta_enabled: + log.debug("[control connection] Refreshing node list without token map") + sel_peers = self._SELECT_PEERS_NO_TOKENS + sel_local = self._SELECT_LOCAL_NO_TOKENS + else: + log.debug("[control connection] Refreshing node list and token map") + sel_peers = self._SELECT_PEERS + sel_local = self._SELECT_LOCAL + peers_query = QueryMessage(query=sel_peers, consistency_level=cl) + local_query = QueryMessage(query=sel_local, consistency_level=cl) peers_result, local_result = connection.wait_for_responses( peers_query, local_query, timeout=self._timeout) @@ -2277,27 +2361,25 @@ class ControlConnection(object): datacenter = local_row.get("data_center") rack = local_row.get("rack") self._update_location_info(host, datacenter, rack) + host.listen_address = local_row.get("listen_address") + host.broadcast_address = local_row.get("broadcast_address") + host.release_version = local_row.get("release_version") partitioner = local_row.get("partitioner") tokens = local_row.get("tokens") if partitioner and tokens: token_map[host] = tokens - connection.server_version = local_row['release_version'] - # Check metadata.partitioner to see if we haven't built anything yet. 
If # every node in the cluster was in the contact points, we won't discover # any new nodes, so we need this additional check. (See PYTHON-90) should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None found_hosts = set() for row in peers_result: - addr = row.get("rpc_address") + addr = self._rpc_from_peer_row(row) - if not addr or addr in ["0.0.0.0", "::"]: - addr = row.get("peer") - - tokens = row.get("tokens") - if not tokens: + tokens = row.get("tokens", None) + if 'tokens' in row and not tokens: # it was selected, but empty log.warning("Excluding host (%s) with no tokens in system.peers table of %s." % (addr, connection.host)) continue @@ -2313,6 +2395,9 @@ class ControlConnection(object): else: should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) + host.broadcast_address = row.get("peer") + host.release_version = row.get("release_version") + if partitioner and tokens: token_map[host] = tokens @@ -2320,7 +2405,7 @@ class ControlConnection(object): if old_host.address != connection.host and old_host.address not in found_hosts: should_rebuild_token_map = True if old_host.address not in self._cluster.contact_points: - log.debug("[control connection] Found host that has been removed: %r", old_host) + log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) self._cluster.remove_host(old_host) log.debug("[control connection] Finished fetching ring info") @@ -2356,7 +2441,7 @@ class ControlConnection(object): def _handle_topology_change(self, event): change_type = event["change_type"] - addr, port = event["address"] + addr = self._translate_address(event["address"][0]) if change_type == "NEW_NODE" or change_type == "MOVED_NODE": if self._topology_event_refresh_window >= 0: delay = self._delay_for_event_type('topology_change', self._topology_event_refresh_window) @@ -2367,7 +2452,7 @@ class ControlConnection(object): def _handle_status_change(self, event): change_type = event["change_type"] - addr, port = event["address"] + addr = self._translate_address(event["address"][0]) host = self._cluster.metadata.get_host(addr) if change_type == "UP": delay = 1 + self._delay_for_event_type('status_change', 0.5) # randomness to avoid thundering herd problem on events @@ -2385,6 +2470,9 @@ class ControlConnection(object): # this will be run by the scheduler self._cluster.on_down(host, is_host_addition=False) + def _translate_address(self, addr): + return self._cluster.address_translator.translate(addr) + def _handle_schema_change(self, event): if self._schema_event_refresh_window < 0: return @@ -2466,11 +2554,7 @@ class ControlConnection(object): schema_ver = row.get('schema_version') if not schema_ver: continue - - addr = row.get("rpc_address") - if not addr or addr in ["0.0.0.0", "::"]: - addr = row.get("peer") - + addr = self._rpc_from_peer_row(row) peer = self._cluster.metadata.get_host(addr) if peer and peer.is_up: versions[schema_ver].add(addr) @@ -2481,6 +2565,12 @@ class ControlConnection(object): return dict((version, list(nodes)) for version, nodes in six.iteritems(versions)) + def _rpc_from_peer_row(self, row): + addr = row.get("rpc_address") + if not addr or addr in ["0.0.0.0", "::"]: + addr = row.get("peer") + return self._translate_address(addr) + def _signal_error(self): with self._lock: if self._is_shutdown: @@ -2529,9 +2619,6 @@ class ControlConnection(object): if connection is self._connection and (connection.is_defunct or connection.is_closed): self.reconnect() - def 
set_meta_refresh_enabled(self, enabled): - self._meta_refresh_enabled = enabled - def _stop_scheduler(scheduler, thread): try: @@ -2626,13 +2713,9 @@ class _Scheduler(object): def refresh_schema_and_set_result(control_conn, response_future, **kwargs): try: - if control_conn._meta_refresh_enabled: - log.debug("Refreshing schema in response to schema change. " - "%s", kwargs) - response_future.is_schema_agreed = control_conn._refresh_schema(response_future._connection, **kwargs) - else: - log.debug("Skipping schema refresh in response to schema change because meta refresh is disabled; " - "%s", kwargs) + log.debug("Refreshing schema in response to schema change. " + "%s", kwargs) + response_future.is_schema_agreed = control_conn._refresh_schema(response_future._connection, **kwargs) except Exception: log.exception("Exception refreshing schema in response to schema change:") response_future.session.submit(control_conn.refresh_schema, **kwargs) @@ -2717,7 +2800,16 @@ class ResponseFuture(object): self._timer.cancel() def _on_timeout(self): - self._set_final_exception(OperationTimedOut(self._errors, self._current_host)) + errors = self._errors + if not errors: + if self.is_schema_agreed: + errors = {self._current_host.address: "Client request timeout. See Session.execute[_async](timeout)"} + else: + connection = getattr(self.session.cluster.control_connection, '_connection') + host = connection.host if connection else 'unknown' + errors = {host: "Request timed out while waiting for schema agreement. See Session.execute[_async](timeout) and Cluster.max_schema_agreement_wait."} + + self._set_final_exception(OperationTimedOut(errors, self._current_host)) def _make_query_plan(self): # convert the list/generator/etc to an iterator so that subsequent @@ -2730,6 +2822,7 @@ class ResponseFuture(object): """ Internal """ # query_plan is an iterator, so this will resume where we last left # off if send_request() is called multiple times + start = time.time() for host in self.query_plan: req_id = self._query(host) if req_id is not None: @@ -2741,6 +2834,9 @@ class ResponseFuture(object): if self._timer is None: self._start_timer() return + if self.timeout is not None and time.time() - start > self.timeout: + self._on_timeout() + return self._set_final_exception(NoHostAvailable( "Unable to complete the operation against any hosts", self._errors)) @@ -2809,7 +2905,7 @@ class ResponseFuture(object): """ # TODO: When timers are introduced, just make this wait if not self._event.is_set(): - raise Exception("warnings cannot be retrieved before ResponseFuture is finalized") + raise DriverException("warnings cannot be retrieved before ResponseFuture is finalized") return self._warnings @property @@ -2827,7 +2923,7 @@ class ResponseFuture(object): """ # TODO: When timers are introduced, just make this wait if not self._event.is_set(): - raise Exception("custom_payload cannot be retrieved before ResponseFuture is finalized") + raise DriverException("custom_payload cannot be retrieved before ResponseFuture is finalized") return self._custom_payload def start_fetching_next_page(self): @@ -2982,15 +3078,17 @@ class ResponseFuture(object): return retry_type, consistency = retry - if retry_type is RetryPolicy.RETRY: + if retry_type in (RetryPolicy.RETRY, RetryPolicy.RETRY_NEXT_HOST): self._query_retries += 1 - self._retry(reuse_connection=True, consistency_level=consistency) + reuse = retry_type == RetryPolicy.RETRY + self._retry(reuse_connection=reuse, consistency_level=consistency) elif retry_type is 
RetryPolicy.RETHROW: self._set_final_exception(response.to_exception()) else: # IGNORE if self._metrics is not None: self._metrics.on_ignore() self._set_final_result(None) + self._errors[self._current_host] = response.to_exception() elif isinstance(response, ConnectionException): if self._metrics is not None: self._metrics.on_connection_error() @@ -3150,32 +3248,34 @@ class ResponseFuture(object): """ return [trace.trace_id for trace in self._query_traces] - def get_query_trace(self, max_wait=None): + def get_query_trace(self, max_wait=None, query_cl=ConsistencyLevel.LOCAL_ONE): """ Fetches and returns the query trace of the last response, or `None` if tracing was not enabled. Note that this may raise an exception if there are problems retrieving the trace - details from Cassandra. If the trace is not available after `max_wait_sec`, + details from Cassandra. If the trace is not available after `max_wait`, :exc:`cassandra.query.TraceUnavailable` will be raised. + + `query_cl` is the consistency level used to poll the trace tables. """ if self._query_traces: - return self._get_query_trace(len(self._query_traces) - 1, max_wait) + return self._get_query_trace(len(self._query_traces) - 1, max_wait, query_cl) - def get_all_query_traces(self, max_wait_per=None): + def get_all_query_traces(self, max_wait_per=None, query_cl=ConsistencyLevel.LOCAL_ONE): """ Fetches and returns the query traces for all query pages, if tracing was enabled. See note in :meth:`~.get_query_trace` regarding possible exceptions. """ if self._query_traces: - return [self._get_query_trace(i, max_wait_per) for i in range(len(self._query_traces))] + return [self._get_query_trace(i, max_wait_per, query_cl) for i in range(len(self._query_traces))] return [] - def _get_query_trace(self, i, max_wait): + def _get_query_trace(self, i, max_wait, query_cl): trace = self._query_traces[i] if not trace.events: - trace.populate(max_wait=max_wait) + trace.populate(max_wait=max_wait, query_cl=query_cl) return trace def add_callback(self, fn, *args, **kwargs): diff --git a/cassandra/connection.py b/cassandra/connection.py index 0624ad13..ae353cbc 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -125,6 +125,7 @@ class _Frame(object): NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK) + class ConnectionException(Exception): """ An unrecoverable error was hit when attempting to use a connection, @@ -210,6 +211,12 @@ class Connection(object): # the number of request IDs that are currently in use. in_flight = 0 + # Max concurrent requests allowed per connection. This is set optimistically high, allowing + # all request ids to be used in protocol version 3+. Normally concurrency would be controlled + # at a higher level by the application or concurrent.execute_concurrent. This attribute + # is for lower-level integrations that want some upper bound without reimplementing. + max_in_flight = 2 ** 15 + # A set of available request IDs. When using the v3 protocol or higher, # this will not initially include all request IDs in order to save memory, # but the set will grow if it is exhausted. 
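Note on the new ``max_in_flight`` attribute above: request ids are handed out from a deque that starts small and only grows up to the configured ceiling, which is what keeps protocol v3+ connections from allocating all 2**15 ids up front. Below is a minimal standalone sketch of that allocation scheme; the ``RequestIdPool`` wrapper and its names are illustrative only, not part of the driver::

    from collections import deque

    class RequestIdPool(object):
        """Grow-on-demand pool of request ids, capped at max_in_flight."""

        def __init__(self, max_in_flight=2 ** 15, initial_size=300):
            self.max_request_id = min(max_in_flight - 1, (2 ** 15) - 1)
            initial = min(initial_size, max_in_flight)
            self.request_ids = deque(range(initial))
            self.highest_request_id = initial - 1

        def get_request_id(self):
            # Reuse a released id if one is available; otherwise mint a new
            # one until the ceiling is reached.
            try:
                return self.request_ids.popleft()
            except IndexError:
                if self.highest_request_id >= self.max_request_id:
                    raise RuntimeError("all request ids are in use")
                self.highest_request_id += 1
                return self.highest_request_id

            def return_request_id(self, request_id):
                self.request_ids.append(request_id)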
@@ -231,8 +238,6 @@ class Connection(object): is_control_connection = False signaled_error = False # used for flagging at the pool level - _server_version = None - _iobuf = None _current_frame = None @@ -241,6 +246,8 @@ class Connection(object): _socket_impl = socket _ssl_impl = ssl + _check_hostname = False + def __init__(self, host='127.0.0.1', port=9042, authenticator=None, ssl_options=None, sockopts=None, compression=True, cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False, @@ -248,7 +255,7 @@ class Connection(object): self.host = host self.port = port self.authenticator = authenticator - self.ssl_options = ssl_options + self.ssl_options = ssl_options.copy() if ssl_options else None self.sockopts = sockopts self.compression = compression self.cql_version = cql_version @@ -260,14 +267,22 @@ class Connection(object): self._requests = {} self._iobuf = io.BytesIO() + if ssl_options: + self._check_hostname = bool(self.ssl_options.pop('check_hostname', False)) + if self._check_hostname: + if not getattr(ssl, 'match_hostname', None): + raise RuntimeError("ssl_options specify 'check_hostname', but ssl.match_hostname is not provided. " + "Patch or upgrade Python to use this option.") + if protocol_version >= 3: - self.max_request_id = (2 ** 15) - 1 - # Don't fill the deque with 2**15 items right away. Start with 300 and add + self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1) + # Don't fill the deque with 2**15 items right away. Start with some and add # more if needed. - self.request_ids = deque(range(300)) - self.highest_request_id = 299 + initial_size = min(300, self.max_in_flight) + self.request_ids = deque(range(initial_size)) + self.highest_request_id = initial_size - 1 else: - self.max_request_id = (2 ** 7) - 1 + self.max_request_id = min(self.max_in_flight, (2 ** 7) - 1) self.request_ids = deque(range(self.max_request_id + 1)) self.highest_request_id = self.max_request_id @@ -319,15 +334,19 @@ class Connection(object): def _connect_socket(self): sockerr = None addresses = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) + if not addresses: + raise ConnectionException("getaddrinfo returned empty list for %s" % (self.host,)) for (af, socktype, proto, canonname, sockaddr) in addresses: try: self._socket = self._socket_impl.socket(af, socktype, proto) if self.ssl_options: if not self._ssl_impl: - raise Exception("This version of Python was not compiled with SSL support") + raise RuntimeError("This version of Python was not compiled with SSL support") self._socket = self._ssl_impl.wrap_socket(self._socket, **self.ssl_options) self._socket.settimeout(self.connect_timeout) self._socket.connect(sockaddr) + if self._check_hostname: + ssl.match_hostname(self._socket.getpeercert(), self.host) sockerr = None break except socket.error as err: @@ -461,7 +480,7 @@ class Connection(object): while True: needed = len(msgs) - messages_sent with self.lock: - available = min(needed, self.max_request_id - self.in_flight) + available = min(needed, self.max_request_id - self.in_flight + 1) request_ids = [self.get_request_id() for _ in range(available)] self.in_flight += available @@ -826,18 +845,6 @@ class Connection(object): def reset_idle(self): self.msg_received = False - @property - def server_version(self): - if self._server_version is None: - query_message = QueryMessage(query="SELECT release_version FROM system.local", consistency_level=ConsistencyLevel.ONE) - message = self.wait_for_response(query_message) - 
self._server_version = message.results[1][0][0] # (col names, rows)[rows][first row][only item] - return self._server_version - - @server_version.setter - def server_version(self, version): - self._server_version = version - def __str__(self): status = "" if self.is_defunct: @@ -908,7 +915,7 @@ class HeartbeatFuture(object): log.debug("Sending options message heartbeat on idle connection (%s) %s", id(connection), connection.host) with connection.lock: - if connection.in_flight < connection.max_request_id: + if connection.in_flight <= connection.max_request_id: connection.in_flight += 1 connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback) else: @@ -921,18 +928,18 @@ class HeartbeatFuture(object): if self._exception: raise self._exception else: - raise OperationTimedOut() + raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,), self.connection.host) def _options_callback(self, response): - if not isinstance(response, SupportedMessage): + if isinstance(response, SupportedMessage): + log.debug("Received options response on connection (%s) from %s", + id(self.connection), self.connection.host) + else: if isinstance(response, ConnectionException): self._exception = response else: self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s" % (response,)) - - log.debug("Received options response on connection (%s) from %s", - id(self.connection), self.connection.host) self._event.set() @@ -964,10 +971,10 @@ class ConnectionHeartbeat(Thread): if connection.is_idle: try: futures.append(HeartbeatFuture(connection, owner)) - except Exception: + except Exception as e: log.warning("Failed sending heartbeat message on connection (%s) to %s", - id(connection), connection.host, exc_info=True) - failed_connections.append((connection, owner)) + id(connection), connection.host) + failed_connections.append((connection, owner, e)) else: connection.reset_idle() else: @@ -984,14 +991,14 @@ class ConnectionHeartbeat(Thread): with connection.lock: connection.in_flight -= 1 connection.reset_idle() - except Exception: + except Exception as e: log.warning("Heartbeat failed for connection (%s) to %s", - id(connection), connection.host, exc_info=True) - failed_connections.append((f.connection, f.owner)) + id(connection), connection.host) + failed_connections.append((f.connection, f.owner, e)) - for connection, owner in failed_connections: + for connection, owner, exc in failed_connections: self._raise_if_stopped() - connection.defunct(Exception('Connection heartbeat failure')) + connection.defunct(exc) owner.return_connection(connection) except self.ShutdownException: pass diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx index 9ffcdd5d..6c29ebde 100644 --- a/cassandra/deserializers.pyx +++ b/cassandra/deserializers.pyx @@ -44,6 +44,8 @@ cdef class Deserializer: cdef class DesBytesType(Deserializer): cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return b"" return to_bytes(buf) # this is to facilitate cqlsh integration, which requires bytearrays for BytesType @@ -51,6 +53,8 @@ cdef class DesBytesType(Deserializer): # deserializers.DesBytesType = deserializers.DesBytesTypeByteArray cdef class DesBytesTypeByteArray(Deserializer): cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return bytearray() return bytearray(buf.ptr[:buf.size]) # TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html @@ -84,6 +88,8 @@ cdef class 
DesByteType(Deserializer): cdef class DesAsciiType(Deserializer): cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return "" if PY2: return to_bytes(buf) return to_bytes(buf).decode('ascii') @@ -169,6 +175,8 @@ cdef class DesTimeType(Deserializer): cdef class DesUTF8Type(Deserializer): cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return "" cdef val = to_bytes(buf) return val.decode('utf8') diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py index 6e62a38b..c825a8c1 100644 --- a/cassandra/io/geventreactor.py +++ b/cassandra/io/geventreactor.py @@ -14,17 +14,16 @@ import gevent import gevent.event from gevent.queue import Queue -from gevent import select, socket +from gevent import socket import gevent.ssl -from functools import partial import logging import os import time from six.moves import range -from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL +from errno import EINVAL from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager @@ -34,9 +33,8 @@ log = logging.getLogger(__name__) def is_timeout(err): return ( - err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or (err == EINVAL and os.name in ('nt', 'ce')) or - isinstance(err, socket.timeout) + isinstance(err, socket.timeout) ) @@ -118,44 +116,23 @@ class GeventConnection(Connection): self.close() def handle_write(self): - run_select = partial(select.select, (), (self._socket,), ()) while True: try: next_msg = self._write_queue.get() - run_select() - except Exception as exc: - if not self.is_closed: - log.debug("Exception during write select() for %s: %s", self, exc) - self.defunct(exc) - return - - try: self._socket.sendall(next_msg) except socket.error as err: - log.debug("Exception during socket sendall for %s: %s", self, err) + log.debug("Exception in send for %s: %s", self, err) self.defunct(err) - return # Leave the write loop - - def handle_read(self): - run_select = partial(select.select, (self._socket,), (), ()) - while True: - try: - run_select() - except Exception as exc: - if not self.is_closed: - log.debug("Exception during read select() for %s: %s", self, exc) - self.defunct(exc) return + def handle_read(self): + while True: try: - while True: - buf = self._socket.recv(self.in_buffer_size) - self._iobuf.write(buf) - if len(buf) < self.in_buffer_size: - break + buf = self._socket.recv(self.in_buffer_size) + self._iobuf.write(buf) except socket.error as err: if not is_timeout(err): - log.debug("Exception during socket recv for %s: %s", self, err) + log.debug("Exception in read for %s: %s", self, err) self.defunct(err) return # leave the read loop diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 1d04b4c9..e9bf0b4b 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -117,13 +117,15 @@ class Metadata(object): def refresh(self, connection, timeout, target_type=None, change_type=None, **kwargs): + server_version = self.get_host(connection.host).release_version + parser = get_schema_parser(connection, server_version, timeout) + if not target_type: - self._rebuild_all(connection, timeout) + self._rebuild_all(parser) return tt_lower = target_type.lower() try: - parser = get_schema_parser(connection, timeout) parse_method = getattr(parser, 'get_' + tt_lower) meta = parse_method(self.keyspaces, **kwargs) if meta: @@ -135,12 +137,7 @@ class Metadata(object): except AttributeError: raise ValueError("Unknown schema target_type: '%s'" % target_type) - def _rebuild_all(self, connection, 
timeout): - """ - For internal use only. - """ - parser = get_schema_parser(connection, timeout) - + def _rebuild_all(self, parser): current_keyspaces = set() for keyspace_meta in parser.get_all_keyspaces(): current_keyspaces.add(keyspace_meta.name) @@ -1402,8 +1399,10 @@ class TokenMap(object): with self._rebuild_lock: current = self.tokens_to_hosts_by_ks.get(keyspace, None) if (build_if_absent and current is None) or (not build_if_absent and current is not None): - replica_map = self.replica_map_for_keyspace(self._metadata.keyspaces[keyspace]) - self.tokens_to_hosts_by_ks[keyspace] = replica_map + ks_meta = self._metadata.keyspaces.get(keyspace) + if ks_meta: + replica_map = self.replica_map_for_keyspace(self._metadata.keyspaces[keyspace]) + self.tokens_to_hosts_by_ks[keyspace] = replica_map def replica_map_for_keyspace(self, ks_metadata): strategy = ks_metadata.replication_strategy @@ -2453,8 +2452,7 @@ class MaterializedViewMetadata(object): return self.as_cql_query(formatted=True) + ";" -def get_schema_parser(connection, timeout): - server_version = connection.server_version +def get_schema_parser(connection, server_version, timeout): if server_version.startswith('3'): return SchemaParserV3(connection, timeout) else: diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx index 920a3efd..1334e747 100644 --- a/cassandra/numpy_parser.pyx +++ b/cassandra/numpy_parser.pyx @@ -57,7 +57,7 @@ arrDescDtype = np.dtype( [ ('buf_ptr', np.uintp) , ('stride', np.dtype('i')) , ('is_object', np.dtype('i')) - ]) + ], align=True) _cqltype_to_numpy = { cqltypes.LongType: np.dtype('>i8'), @@ -145,13 +145,13 @@ cdef inline int unpack_row( get_buf(reader, &buf) arr = arrays[i] - if buf.size == 0: - raise ValueError("Cannot handle NULL value") if arr.is_object: deserializer = desc.deserializers[i] val = from_binary(deserializer, &buf, desc.protocol_version) Py_INCREF(val) ( arr.buf_ptr)[0] = val + elif buf.size < 0: + raise ValueError("Cannot handle NULL value") else: memcpy( arr.buf_ptr, buf.ptr, buf.size) diff --git a/cassandra/policies.py b/cassandra/policies.py index 595717ca..4725f068 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -16,11 +16,8 @@ from itertools import islice, cycle, groupby, repeat import logging from random import randint from threading import Lock -import six -from cassandra import ConsistencyLevel - -from six.moves import range +from cassandra import ConsistencyLevel, OperationTimedOut log = logging.getLogger(__name__) @@ -235,7 +232,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): self._dc_live_hosts[dc] = tuple(set(dc_hosts)) if not self.local_dc: - self._contact_points = cluster.contact_points + self._contact_points = cluster.contact_points_resolved self._position = randint(0, len(hosts) - 1) if hosts else 0 @@ -337,7 +334,7 @@ class TokenAwarePolicy(LoadBalancingPolicy): def check_supported(self): if not self._cluster_metadata.can_support_partitioner(): - raise Exception( + raise RuntimeError( '%s cannot be used with the cluster partitioner (%s) because ' 'the relevant C extension for this driver was not compiled. 
' 'See the installation instructions for details on building ' @@ -468,7 +465,7 @@ class SimpleConvictionPolicy(ConvictionPolicy): """ def add_failure(self, connection_exc): - return True + return not isinstance(connection_exc, OperationTimedOut) def reset(self): pass @@ -554,8 +551,8 @@ class ExponentialReconnectionPolicy(ReconnectionPolicy): self.max_attempts = max_attempts def new_schedule(self): - i=0 - while self.max_attempts == None or i < self.max_attempts: + i = 0 + while self.max_attempts is None or i < self.max_attempts: yield min(self.base_delay * (2 ** i), self.max_delay) i += 1 @@ -650,6 +647,12 @@ class RetryPolicy(object): should be ignored but no more retries should be attempted. """ + RETRY_NEXT_HOST = 3 + """ + This should be returned from the below methods if the operation + should be retried on another connection. + """ + def on_read_timeout(self, query, consistency, required_responses, received_responses, data_retrieved, retry_num): """ @@ -677,11 +680,11 @@ class RetryPolicy(object): a sufficient number of replicas responded (with data digests). """ if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None elif received_responses >= required_responses and not data_retrieved: - return (self.RETRY, consistency) + return self.RETRY, consistency else: - return (self.RETHROW, None) + return self.RETHROW, None def on_write_timeout(self, query, consistency, write_type, required_responses, received_responses, retry_num): @@ -710,11 +713,11 @@ class RetryPolicy(object): :attr:`~.WriteType.BATCH_LOG`. """ if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None elif write_type == WriteType.BATCH_LOG: - return (self.RETRY, consistency) + return self.RETRY, consistency else: - return (self.RETHROW, None) + return self.RETHROW, None def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): """ @@ -739,7 +742,7 @@ class RetryPolicy(object): By default, no retries will be attempted and the error will be re-raised. 
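As a hedged illustration of the new RETRY_NEXT_HOST decision (not part of this patch), a user-defined retry policy could opt into the same behavior the default policy now applies on the first unavailable error; the subclass name is invented, and wiring it up assumes Cluster's existing default_retry_policy option:

from cassandra.cluster import Cluster
from cassandra.policies import RetryPolicy


class NextHostOnUnavailable(RetryPolicy):
    # Illustrative subclass: on the first Unavailable error, ask the driver to
    # retry the request on the next host in the query plan at the same
    # consistency level; on later attempts, re-raise to the caller.
    def on_unavailable(self, query, consistency, required_replicas,
                       alive_replicas, retry_num):
        if retry_num == 0:
            return self.RETRY_NEXT_HOST, consistency
        return self.RETHROW, None


cluster = Cluster(default_retry_policy=NextHostOnUnavailable())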
""" - return (self.RETHROW, None) + return (self.RETRY_NEXT_HOST, consistency) if retry_num == 0 else (self.RETHROW, None) class FallthroughRetryPolicy(RetryPolicy): @@ -749,13 +752,13 @@ class FallthroughRetryPolicy(RetryPolicy): """ def on_read_timeout(self, *args, **kwargs): - return (self.RETHROW, None) + return self.RETHROW, None def on_write_timeout(self, *args, **kwargs): - return (self.RETHROW, None) + return self.RETHROW, None def on_unavailable(self, *args, **kwargs): - return (self.RETHROW, None) + return self.RETHROW, None class DowngradingConsistencyRetryPolicy(RetryPolicy): @@ -807,45 +810,73 @@ class DowngradingConsistencyRetryPolicy(RetryPolicy): """ def _pick_consistency(self, num_responses): if num_responses >= 3: - return (self.RETRY, ConsistencyLevel.THREE) + return self.RETRY, ConsistencyLevel.THREE elif num_responses >= 2: - return (self.RETRY, ConsistencyLevel.TWO) + return self.RETRY, ConsistencyLevel.TWO elif num_responses >= 1: - return (self.RETRY, ConsistencyLevel.ONE) + return self.RETRY, ConsistencyLevel.ONE else: - return (self.RETHROW, None) + return self.RETHROW, None def on_read_timeout(self, query, consistency, required_responses, received_responses, data_retrieved, retry_num): if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None elif received_responses < required_responses: return self._pick_consistency(received_responses) elif not data_retrieved: - return (self.RETRY, consistency) + return self.RETRY, consistency else: - return (self.RETHROW, None) + return self.RETHROW, None def on_write_timeout(self, query, consistency, write_type, required_responses, received_responses, retry_num): if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None if write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER): if received_responses > 0: # persisted on at least one replica - return (self.IGNORE, None) + return self.IGNORE, None else: - return (self.RETHROW, None) + return self.RETHROW, None elif write_type == WriteType.UNLOGGED_BATCH: return self._pick_consistency(received_responses) elif write_type == WriteType.BATCH_LOG: - return (self.RETRY, consistency) + return self.RETRY, consistency - return (self.RETHROW, None) + return self.RETHROW, None def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None else: return self._pick_consistency(alive_replicas) + + +class AddressTranslator(object): + """ + Interface for translating cluster-defined endpoints. + + The driver discovers nodes using server metadata and topology change events. Normally, + the endpoint defined by the server is the right way to connect to a node. In some environments, + these addresses may not be reachable, or not preferred (public vs. private IPs in cloud environments, + suboptimal routing, etc). This interface allows for translating from server defined endpoints to + preferred addresses for driver connections. + + *Note:* :attr:`~Cluster.contact_points` provided while creating the :class:`~.Cluster` instance are not + translated using this mechanism -- only addresses received from Cassandra nodes are. + """ + def translate(self, addr): + """ + Accepts the node ip address, and returns a translated address to be used connecting to this node. 
+ """ + raise NotImplementedError + + +class IdentityTranslator(AddressTranslator): + """ + Returns the endpoint with no translation + """ + def translate(self, addr): + return addr diff --git a/cassandra/pool.py b/cassandra/pool.py index 08cfd372..134eb254 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -48,7 +48,21 @@ class Host(object): address = None """ - The IP address or hostname of the node. + The IP address of the node. This is the RPC address the driver uses when connecting to the node + """ + + broadcast_address = None + """ + broadcast address configured for the node, *if available* ('peer' in system.peers table). + This is not present in the ``system.local`` table for older versions of Cassandra. It is also not queried if + :attr:`~.Cluster.token_metadata_enabled` is ``False``. + """ + + listen_address = None + """ + listen address configured for the node, *if available*. This is only available in the ``system.local`` table for newer + versions of Cassandra. It is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + Usually the same as ``broadcast_address`` unless configured differently in cassandra.yaml. """ conviction_policy = None @@ -64,6 +78,11 @@ class Host(object): up or down. """ + release_version = None + """ + release_version as queried from the control connection system tables + """ + _datacenter = None _rack = None _reconnection_handler = None @@ -282,6 +301,8 @@ class HostConnection(object): self.host_distance = host_distance self._session = weakref.proxy(session) self._lock = Lock() + # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool. + self._stream_available_condition = Condition(self._lock) self._is_replacing = False if host_distance == HostDistance.IGNORED: @@ -306,16 +327,27 @@ class HostConnection(object): if not conn: raise NoConnectionsAvailable() - with conn.lock: - if conn.in_flight < conn.max_request_id: - conn.in_flight += 1 - return conn, conn.get_request_id() + start = time.time() + remaining = timeout + while True: + with conn.lock: + if conn.in_flight <= conn.max_request_id: + conn.in_flight += 1 + return conn, conn.get_request_id() + if timeout is not None: + remaining = timeout - time.time() + start + if remaining < 0: + break + with self._stream_available_condition: + self._stream_available_condition.wait(remaining) raise NoConnectionsAvailable("All request IDs are currently in use") def return_connection(self, connection): with connection.lock: connection.in_flight -= 1 + with self._stream_available_condition: + self._stream_available_condition.notify() if (connection.is_defunct or connection.is_closed) and not connection.signaled_error: log.debug("Defunct or closed connection (%s) returned to pool, potentially " @@ -335,12 +367,18 @@ class HostConnection(object): def _replace(self, connection): log.debug("Replacing connection (%s) to %s", id(connection), self.host) - conn = self._session.cluster.connection_factory(self.host.address) - if self._session.keyspace: - conn.set_keyspace_blocking(self._session.keyspace) - self._connection = conn - with self._lock: - self._is_replacing = False + try: + conn = self._session.cluster.connection_factory(self.host.address) + if self._session.keyspace: + conn.set_keyspace_blocking(self._session.keyspace) + self._connection = conn + except Exception: + log.warning("Failed reconnecting %s. Retrying." 
% (self.host.address,)) + self._session.submit(self._replace, connection) + else: + with self._lock: + self._is_replacing = False + self._stream_available_condition.notify() def shutdown(self): with self._lock: @@ -348,6 +386,7 @@ class HostConnection(object): return else: self.is_shutdown = True + self._stream_available_condition.notify_all() if self._connection: self._connection.close() @@ -499,10 +538,10 @@ class HostConnectionPool(object): max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance) with self._lock: if self.is_shutdown: - return False + return True if self.open_count >= max_conns: - return False + return True self.open_count += 1 @@ -646,17 +685,22 @@ class HostConnectionPool(object): if should_replace: log.debug("Replacing connection (%s) to %s", id(connection), self.host) - - def close_and_replace(): - connection.close() - self._add_conn_if_under_max() - - self._session.submit(close_and_replace) + connection.close() + self._session.submit(self._retrying_replace) else: - # just close it log.debug("Closing connection (%s) to %s", id(connection), self.host) connection.close() + def _retrying_replace(self): + replaced = False + try: + replaced = self._add_conn_if_under_max() + except Exception: + log.exception("Failed replacing connection to %s", self.host) + if not replaced: + log.debug("Failed replacing connection to %s. Retrying.", self.host) + self._session.submit(self._retrying_replace) + def shutdown(self): with self._lock: if self.is_shutdown: diff --git a/cassandra/protocol.py b/cassandra/protocol.py index d2b94673..ede6cc58 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -22,7 +22,7 @@ import six from six.moves import range import io -from cassandra import type_codes +from cassandra import type_codes, DriverException from cassandra import (Unavailable, WriteTimeout, ReadTimeout, WriteFailure, ReadFailure, FunctionFailure, AlreadyExists, InvalidRequest, Unauthorized, @@ -589,7 +589,7 @@ class ResultMessage(_MessageType): elif kind == RESULT_KIND_SCHEMA_CHANGE: results = cls.recv_results_schema_change(f, protocol_version) else: - raise Exception("Unknown RESULT kind: %d" % kind) + raise DriverException("Unknown RESULT kind: %d" % kind) return cls(kind, results, paging_state) @classmethod @@ -971,7 +971,7 @@ class _ProtocolHandler(object): """ if flags & COMPRESSED_FLAG: if decompressor is None: - raise Exception("No de-compressor available for compressed frame!") + raise RuntimeError("No de-compressor available for compressed frame!") body = decompressor(body) flags ^= COMPRESSED_FLAG diff --git a/cassandra/query.py b/cassandra/query.py index 55d02300..1fb4d4cc 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -31,7 +31,7 @@ from cassandra.util import unix_time_from_uuid1 from cassandra.encoder import Encoder import cassandra.encoder from cassandra.protocol import _UNSET_VALUE -from cassandra.util import OrderedDict, _positional_rename_invalid_identifiers +from cassandra.util import OrderedDict, _sanitize_identifiers import logging log = logging.getLogger(__name__) @@ -117,13 +117,14 @@ def named_tuple_factory(colnames, rows): try: Row = namedtuple('Row', clean_column_names) except Exception: + clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " "(see Python 'namedtuple' documentation for details on name rules). 
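To make the deduplication behavior concrete, here is a small sketch exercising the private helper added just below in cassandra/util.py; it is for illustration only and relies on an internal function:

from cassandra.util import _sanitize_identifiers

# Duplicate "cleaned" column names are made unique by appending underscores,
# so namedtuple creation no longer fails (PYTHON-467).
print(_sanitize_identifiers(['foo_bar', 'foo_bar']))
# expected: ['foo_bar', 'foo_bar_']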
" "Results will be returned with positional names. " "Avoid this by choosing different names, using SELECT \"\" AS aliases, " "or specifying a different row_factory on your Session" % (colnames, clean_column_names)) - Row = namedtuple('Row', _positional_rename_invalid_identifiers(clean_column_names)) + Row = namedtuple('Row', _sanitize_identifiers(clean_column_names)) return [Row(*row) for row in rows] @@ -864,7 +865,7 @@ class QueryTrace(object): self.trace_id = trace_id self._session = session - def populate(self, max_wait=2.0, wait_for_complete=True): + def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): """ Retrieves the actual tracing details from Cassandra and populates the attributes of this instance. Because tracing details are stored @@ -875,6 +876,9 @@ class QueryTrace(object): `wait_for_complete=False` bypasses the wait for duration to be populated. This can be used to query events from partial sessions. + + `query_cl` specifies a consistency level to use for polling the trace tables, + if it should be different than the session default. """ attempt = 0 start = time.time() @@ -886,7 +890,7 @@ class QueryTrace(object): log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) session_results = self._execute( - self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait) + SimpleStatement(self._SELECT_SESSIONS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) is_complete = session_results and session_results[0].duration is not None if not session_results or (wait_for_complete and not is_complete): @@ -910,7 +914,7 @@ class QueryTrace(object): log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) time_spent = time.time() - start event_results = self._execute( - self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait) + SimpleStatement(self._SELECT_EVENTS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) log.debug("Fetched trace events for trace ID: %s", self.trace_id) self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) for r in event_results) diff --git a/cassandra/util.py b/cassandra/util.py index ab6968e0..ec590e4b 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -1101,7 +1101,7 @@ else: WSAAddressToStringA = ctypes.windll.ws2_32.WSAAddressToStringA else: def not_windows(*args): - raise Exception("IPv6 addresses cannot be handled on Windows. " + raise OSError("IPv6 addresses cannot be handled on Windows. " "Missing ctypes.windll") WSAStringToAddressA = not_windows WSAAddressToStringA = not_windows @@ -1171,3 +1171,14 @@ def _positional_rename_invalid_identifiers(field_names): or name.startswith('_')): names_out[index] = 'field_%d_' % index return names_out + + +def _sanitize_identifiers(field_names): + names_out = _positional_rename_invalid_identifiers(field_names) + if len(names_out) != len(set(names_out)): + observed_names = set() + for index, name in enumerate(names_out): + while names_out[index] in observed_names: + names_out[index] = "%s_" % (names_out[index],) + observed_names.add(names_out[index]) + return names_out diff --git a/docs/api/cassandra.rst b/docs/api/cassandra.rst index 5628099a..fd4273f4 100644 --- a/docs/api/cassandra.rst +++ b/docs/api/cassandra.rst @@ -22,6 +22,12 @@ :members: :inherited-members: +.. autoexception:: DriverException() + :members: + +.. autoexception:: RequestExecutionException() + :members: + .. autoexception:: Unavailable() :members: @@ -34,6 +40,9 @@ .. 
autoexception:: WriteTimeout() :members: +.. autoexception:: CoordinationFailure() + :members: + .. autoexception:: ReadFailure() :members: @@ -43,6 +52,12 @@ .. autoexception:: FunctionFailure() :members: +.. autoexception:: RequestValidationException() + :members: + +.. autoexception:: ConfigurationException() + :members: + .. autoexception:: AlreadyExists() :members: diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst index 9c546d3b..fdea3e67 100644 --- a/docs/api/cassandra/cluster.rst +++ b/docs/api/cassandra/cluster.rst @@ -25,6 +25,8 @@ .. autoattribute:: conviction_policy_factory + .. autoattribute:: address_translator + .. autoattribute:: connection_class .. autoattribute:: metrics_enabled @@ -49,6 +51,12 @@ .. autoattribute:: connect_timeout + .. autoattribute:: schema_metadata_enabled + :annotation: = True + + .. autoattribute:: token_metadata_enabled + :annotation: = True + .. automethod:: connect .. automethod:: shutdown diff --git a/docs/api/cassandra/policies.rst b/docs/api/cassandra/policies.rst index 44346c4b..c96e491a 100644 --- a/docs/api/cassandra/policies.rst +++ b/docs/api/cassandra/policies.rst @@ -24,6 +24,15 @@ Load Balancing .. autoclass:: TokenAwarePolicy :members: +Translating Server Node Addresses +--------------------------------- + +.. autoclass:: AddressTranslator + :members: + +.. autoclass:: IdentityTranslator + :members: + Marking Hosts Up or Down ------------------------ diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 4b3afaa0..cd557575 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -86,6 +86,15 @@ def _tuple_version(version_string): USE_CASS_EXTERNAL = bool(os.getenv('USE_CASS_EXTERNAL', False)) +# If set to to true this will force the Cython tests to run regardless of whether they are installed +cython_env = os.getenv('VERIFY_CYTHON', "False") + + +VERIFY_CYTHON = False + +if(cython_env == 'True'): + VERIFY_CYTHON = True + default_cassandra_version = '2.2.0' @@ -350,6 +359,22 @@ def execute_with_long_wait_retry(session, query, timeout=30): raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) +def execute_with_retry_tolerant(session, query, retry_exceptions, escape_exception): + # TODO refactor above methods into this one for code reuse + tries = 0 + while tries < 100: + try: + tries += 1 + rs = session.execute(query) + return rs + except escape_exception: + return + except retry_exceptions: + time.sleep(.1) + + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + + def drop_keyspace_shutdown_cluster(keyspace_name, session, cluster): try: execute_with_long_wait_retry(session, "DROP KEYSPACE {0}".format(keyspace_name)) diff --git a/tests/integration/long/ssl/.keystore b/tests/integration/long/ssl/.keystore new file mode 100644 index 00000000..58ab9696 Binary files /dev/null and b/tests/integration/long/ssl/.keystore differ diff --git a/tests/integration/long/ssl/.truststore b/tests/integration/long/ssl/.truststore new file mode 100644 index 00000000..80bbc1fd Binary files /dev/null and b/tests/integration/long/ssl/.truststore differ diff --git a/tests/integration/long/ssl/cassandra.crt b/tests/integration/long/ssl/cassandra.crt deleted file mode 100644 index 432e5854..00000000 Binary files a/tests/integration/long/ssl/cassandra.crt and /dev/null differ diff --git a/tests/integration/long/ssl/cassandra.pem b/tests/integration/long/ssl/cassandra.pem new file mode 100644 index 00000000..43c0a3e4 --- 
/dev/null +++ b/tests/integration/long/ssl/cassandra.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDnzCCAoegAwIBAgIEG7jtLDANBgkqhkiG9w0BAQsFADB/MQswCQYDVQQGEwJVUzETMBEGA1UE +CBMKQ2FsaWZvcm5pYTEUMBIGA1UEBxMLU2FudGEgQ2xhcmExFjAUBgNVBAoTDURhdGFTdGF4IElu +Yy4xGTAXBgNVBAsTEFBIUCBEcml2ZXIgVGVzdHMxEjAQBgNVBAMTCTEyNy4wLjAuMTAgFw0xNjA0 +MTkxNTIzNDBaGA8yMTE2MDMyNjE1MjM0MFowfzELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlm +b3JuaWExFDASBgNVBAcTC1NhbnRhIENsYXJhMRYwFAYDVQQKEw1EYXRhU3RheCBJbmMuMRkwFwYD +VQQLExBQSFAgRHJpdmVyIFRlc3RzMRIwEAYDVQQDEwkxMjcuMC4wLjEwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCkFMlcxnl+k5KAzt6g1GSQ/kXzHtQXcf0//x6BQTRYdOEeiBnMcI+o +HYefiwGpDslr3YxcWwJSfKjgUhGj+2wyl4O8PP0Up47tX5HUQuIlIjjBZd465VhQh6DqaJky3YHX +KD+8eHuVMnEyAImsNh9laQkUOHsGT/POpI77IBxS1hVVOu6A5bYz17D0RAzZBel3eZBWLSSgbkSG +jUPIDY+078qRJI56xY/6lEM5Zr6DJ96jTdqjRPFv3fHJJZnwCNz0ng0wB/gHYFkm2fdGAM2jrCdD +jC+VZK6uyXatRbFanPdlfZ4rWPrH7V0c6wrDssuUMlDIdaMyHC89axZLP5ZZAgMBAAGjITAfMB0G +A1UdDgQWBBR/0GR2jGRuP8rsmBept17gwXcRejANBgkqhkiG9w0BAQsFAAOCAQEAe33a2GITd7DP +WmDCFfsZqCDIbHdx1QwfcvWRpSnjXpN4Muzrt4TCRJAx6kNQZDnG2aNMdGWGjrQAXUMLkJfSnwJX +X1f3LteNfekWB/lN6kVbPy5lSuLT45q3xSLpTUE51S3zG/c+qyi3FZgYA729ieuDW8VTQdF9hClN +Ik8Wy5Gg87AdYY71DvG9IQkg9eAGVeZSnfMUDYpAiZaavsYJwFfYqIT/WCz9FNXPYN1tDNoV3u71 +GTPrywMLZlmmGn3EBvbh1mR25EmPkxAw6LkWyfbbzrskhXmzN+j9TZNN9MiYMNtisWBR9afsVpLq +Bw4ukuih6CyqUxwuVBNhns8iFA== +-----END CERTIFICATE----- diff --git a/tests/integration/long/ssl/driver.crt b/tests/integration/long/ssl/driver.crt new file mode 100644 index 00000000..76c99990 Binary files /dev/null and b/tests/integration/long/ssl/driver.crt differ diff --git a/tests/integration/long/ssl/driver.key b/tests/integration/long/ssl/driver.key new file mode 100644 index 00000000..bcb033d3 --- /dev/null +++ b/tests/integration/long/ssl/driver.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAohKfXUHM5AQvwgTPIretqenYVJzo78z5IE/xjUe0/e24lJCe +qCYaJPHB3kjg05ZQhBWZFsPhxMwh7zvTaLiIfjK9EuDvBP+dhxLbWShXcL8NNu45 +dfM4BebewR82Kk4rAMHgAaWDPjnyo4XJ+DUMlXMzbMrCT35k0R2ddwuEF/OI6dIn +AX/VVDtOaWaiCP6BC4rt9r2y1UjJ5xHAK+bPjev7Y1TrSd8maTQTFzORlVRamwpQ +tMXDJfWyan5+rbAfi74Qk6kBuc25ZWKpYKJlKGgHy9b9875G76K63boXe8hDcDE1 +mJHOEReb/n533OWmtTUH2z9Wz/ZLfE7E784DtQIDAQABAoIBAEjCDXMDMCUI0BvM +puSSCs9loG1Tx6qnnQtpeCoRj+oTBzMCvYcFG+WL9pmApBW4vEqfm4tBwp/jeDR6 +6gxwEy58Pq8z9XmL+z8BMyGSX7bIvfjz9y3BU1wq6cNO7LUPRHThlFuI3yNXRlss +m5/myNBq2CI14Adp2X6bSe6AZ/uL6+Ez8ZLpNiMUtGzTDILyH8bkgRHsyDxoDpF+ +h1r42dZG3ue4JC3oqUE2epvNE5Rzh02OTBSYYbmOmFWkhV2lm6DH46lycFHjkzcm +2meU7q9I7nZrf0+ZkkQXhWGQWcGzHd3si+XMhaHT0BnH8gGFZDzpC/qnvZdIn//y +zDu2B/0CgYEA43kkgZ1YGqi5Vi4ngek/uk4sJ9dOvYIaAag+u5ybRinLhCFDGhr9 +MzksqFY7naq9oeZZrLbftw2FsKsCMYAOrrxZIbhvGm538jwWmBefEmD0Ww+k/WKG +AAv0z4sSnSOq2/sZBQJlOrk8K6wQ+FcyEddiy16zyj9UCzEPSDQoNj8CgYEAtmXY +xYALbGz0peDGfzPq/WjeqCJtjlXIKOB7+vdyi88YE4Tea2vfBmAAWmbTf/el/o8R +txDfP6+RKxXs2jCFCY640W83ar3RjXcwtu+oH4Aqa8S21o6Dx2sx3HOgKoJ3DGGB +HHmYczuDZmN/zHseBrYUf1UNcIDQIWZCLKImkQsCgYEAuY735bfJyC18yr1pwwLX +1o2wXWu4ssSvnEx3qCmvTIQJnaZYi7IafC+Wq1d02EAQ40H6QvcG9ddVCHHnnyOc +VpQUjXpbP49/zx2PPNCAFcj7bFENwRkewGkAclE7hO40kbB6j2iN1WKHoASD72GJ +2Z3+3HFLbp9MWrjRg4/wgZcCgYAk2IgkxYwJOC1nwPN4SM2kqWWpQ2MsSKnpkAeb +mBccpRYxAztHKgBgsk3/9RuukyGGPfKPL6pZnCbQNFqnbPvDBYDSVgw01OmbEUPX +AKzOyD5JjPB+gUWfqEjnRrhJPhe8eYnybaHdTV1q9piffxN+uZOEcXMIkgz5YkXl +7E+sJwKBgFLA2CS19lbcoi5R8AJbUydZJr/LNAbFknWnliiq3GxISfWydA9cG/dI +CxV3297awLhHrCWqziC0zITjEcAhsNkfG/VQlYGJOS3sfvMSrLuW/9bAL8o4VCpC +cOs9e+svbJukJB6UQu4vpROMmv+0quXM325VlCZNel7DPAovYwjW +-----END RSA PRIVATE KEY----- diff --git 
a/tests/integration/long/ssl/driver.pem b/tests/integration/long/ssl/driver.pem new file mode 100644 index 00000000..f5aaa25b --- /dev/null +++ b/tests/integration/long/ssl/driver.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDnzCCAoegAwIBAgIEJErKsDANBgkqhkiG9w0BAQsFADB/MQswCQYDVQQGEwJVUzETMBEGA1UE +CBMKQ2FsaWZvcm5pYTEUMBIGA1UEBxMLU2FudGEgQ2xhcmExFjAUBgNVBAoTDURhdGFTdGF4IElu +Yy4xGTAXBgNVBAsTEFBIUCBEcml2ZXIgVGVzdHMxEjAQBgNVBAMTCTEyNy4wLjAuMTAgFw0xNjA0 +MTkxNTIzNDVaGA8yMTE2MDMyNjE1MjM0NVowfzELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlm +b3JuaWExFDASBgNVBAcTC1NhbnRhIENsYXJhMRYwFAYDVQQKEw1EYXRhU3RheCBJbmMuMRkwFwYD +VQQLExBQSFAgRHJpdmVyIFRlc3RzMRIwEAYDVQQDEwkxMjcuMC4wLjEwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCiEp9dQczkBC/CBM8it62p6dhUnOjvzPkgT/GNR7T97biUkJ6oJhok +8cHeSODTllCEFZkWw+HEzCHvO9NouIh+Mr0S4O8E/52HEttZKFdwvw027jl18zgF5t7BHzYqTisA +weABpYM+OfKjhcn4NQyVczNsysJPfmTRHZ13C4QX84jp0icBf9VUO05pZqII/oELiu32vbLVSMnn +EcAr5s+N6/tjVOtJ3yZpNBMXM5GVVFqbClC0xcMl9bJqfn6tsB+LvhCTqQG5zbllYqlgomUoaAfL +1v3zvkbvorrduhd7yENwMTWYkc4RF5v+fnfc5aa1NQfbP1bP9kt8TsTvzgO1AgMBAAGjITAfMB0G +A1UdDgQWBBR8aJLDSgkUMcrs08BbxhRA1wJIIzANBgkqhkiG9w0BAQsFAAOCAQEAoHRggyaMKbeB +633sZgzH8DzvngzA/vID+XWAv+lCGdIYNkbu9VJ8IaYsa9JsMvvhp7UFL1mYm32QacjorxqfNTNS +To8z4VOXrGLYkuJL1M2qJjkl3ehkX8tzKXyDIgq4pVCvKkFZR0It+QU87MnHUL1/HIOy+zdNW6ZU +Q7sRCUMtqstiQ4scbispsVevfEBkGNjHIp6M/5Qe6skntRvdNMWZILz82GLym+NppTgcNcwDf7lq +g/syNznM0KAE1VUAJ2y8tArvAZ/XugC2RmZGwY3q/qw1B7kaoTqu7KSdLuWzol5gR0NNVADU5x7U +BrwBmgliT/bGpRz+PAGVYj5OXw== +-----END CERTIFICATE----- diff --git a/tests/integration/long/ssl/driver_ca_cert.pem b/tests/integration/long/ssl/driver_ca_cert.pem deleted file mode 100644 index 7e555557..00000000 --- a/tests/integration/long/ssl/driver_ca_cert.pem +++ /dev/null @@ -1,16 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICbjCCAdegAwIBAgIEP/N06DANBgkqhkiG9w0BAQsFADBqMQswCQYDVQQGEwJU -RTELMAkGA1UECBMCQ0ExFDASBgNVBAcTC1NhbnRhIENsYXJhMREwDwYDVQQKEwhE -YXRhU3RheDELMAkGA1UECxMCVEUxGDAWBgNVBAMTD1BoaWxpcCBUaG9tcHNvbjAe -Fw0xNDEwMDMxNTQ2NDdaFw0xNTAxMDExNTQ2NDdaMGoxCzAJBgNVBAYTAlRFMQsw -CQYDVQQIEwJDQTEUMBIGA1UEBxMLU2FudGEgQ2xhcmExETAPBgNVBAoTCERhdGFT -dGF4MQswCQYDVQQLEwJURTEYMBYGA1UEAxMPUGhpbGlwIFRob21wc29uMIGfMA0G -CSqGSIb3DQEBAQUAA4GNADCBiQKBgQChGDwrhpQR0d+NoqilMgsBlR6A2Dd1oMyI -Ue42sU4tN63g5N4adasfasdfsWgnAkP332ok3YAuVbxytwEv2K9HrUSiokAiuinl -hhHA8CXTHt/1ItzzWj9uJ3Hneb+5lOkXVTZX7Y+q3aSdpx/HnZqn4i27DtLZF0z3 -LccWPWRinQIDAQABoyEwHzAdBgNVHQ4EFgQU9WJpUhgGTBBH4xZBCV7Y9YISCp4w -DQYJKoZIhvcNAQELBQADgYEAF6e8eVAjoZhfyJ+jW5mB0pXa2vr5b7VFQ45voNnc -GrB3aNbz/AWT7LCJw88+Y5SJITgwN/8o3ZY6Y3MyiqeQYGo9WxDSWb5AdZWFa03Z -+hrVDQuw1r118zIhdS4KYDCQM2JfWY32TwK0MNG/6BO876HfkDpcjCYzq8Gh0gEg -uOA= ------END CERTIFICATE----- diff --git a/tests/integration/long/ssl/python_driver.crt b/tests/integration/long/ssl/python_driver.crt deleted file mode 100644 index 0a419f4e..00000000 Binary files a/tests/integration/long/ssl/python_driver.crt and /dev/null differ diff --git a/tests/integration/long/ssl/python_driver.jks b/tests/integration/long/ssl/python_driver.jks deleted file mode 100644 index 9a0fd59f..00000000 Binary files a/tests/integration/long/ssl/python_driver.jks and /dev/null differ diff --git a/tests/integration/long/ssl/python_driver.key b/tests/integration/long/ssl/python_driver.key deleted file mode 100644 index afd73b29..00000000 --- a/tests/integration/long/ssl/python_driver.key +++ /dev/null @@ -1,34 +0,0 @@ -Bag Attributes - friendlyName: python_driver - localKeyID: 54 69 6D 65 20 31 34 33 35 33 33 33 34 30 34 33 33 32 -Key Attributes: ------BEGIN RSA 
PRIVATE KEY----- -Proc-Type: 4,ENCRYPTED -DEK-Info: DES-EDE3-CBC,8A0BC9CFBBB36D47 - -J3Rh82LhsNdIdCV4KCp758VIJJnmedwtq/I9oxH5kY4XoUQjfNcvLGlEnbAUD6+N -mYnQ5XPDvD7iC19XvlA9gfaoWERq+zroGEP+e4dX1X5RlT6YQBJpJR8IW4DWngDM -Nv6CuaGFJWMH8QUvKlJyFOPOHBqbhsCRaxg3pOG3RyUFXpGPDV0ySUyp6poHE9KE -pEVif/SdS3AhV2sb4tyBS9sRZdH1eeCN4gY6k9PQWyNViAgUYAG5xWsE4fITa3qY -gisyzbOYU8ue2QvmjPJgieiKPQf+st/ZRV5eQUCdUgAfLEnULGJXRZ5kw7kMXL0X -gLaKFbGxv4pKQCDCZQq4GXIA/nmTy6cme+VPKwq3usm+GdxfdWQJjgG65+AFaut/ -XjGm1fvSQzWuzpesfLy57HMK+bBh1/UKjuQa3wAHtgPtJLtUSW+/qBnQRdBbl20C -dJtJXyyTlX6H8bQBIfBLc4ntUwS8fVd2jsYJRpCBY6HdtpfsZZ5gQnm1Nux4ksKn -rYYx3I/JpChr7AV7Yj/lwc3Zca/VJl16CjyWeRTQEvkl6GK958aIzj73HfXleZc6 -HGVfOgo2BLmOzY0ZCq/Wa4fnURRgrC3SusrT9mjVbID91oNYw4BjMEU53u0uxPC+ -rr6SwG2EUVawGTVK4XZw2DINCPP/wsKqf0xqA+sxArcTN/MEdLUBdf8oDntkj2jG -Oy0kwpjqhSvWo1DqYKZjV/wKT2SS18OMAW+0qplbHw1/FDGWK+OseD8MXwBo06a5 -LWRQXhf0kEXUQ+oNj3eahe/npHiNChR6mEiIbCuE3NAXPPXJNkhMuj2f5EqrOPfZ -jqbNiLfKKx7L5t6B8LXkdKGPqztcFlnB8rRF9Eqa8F4wiEg8MBLrPyxgd/uT+NIz -LdDgvUE+IkCwQoYoCU70ApiEOyQNacuSxwUiVWVyn9CJYXPM4Vlje7GDIDRR5Xp6 -zNf0ktNP46PsRqDlYG9hZWndj4PRaAqtatlEEm37rmyouVBe3rxcbL1b1zsH/p1I -eaGGTyZ8+iEiuEk4gCOmfmYmpE7H/DXlQvtDRblid/bEY64Uietx0HQ5yZwXZYi8 -hb4itke6xkgRQEIXVyQOdU88PEuA5yofEGoXkfdLgtdu3erPrVDc+nQTYrMWNacR -JQljfhAFJdjOw81Yd5PnFHAtxcxzqEkWv0TGQLL1VjJdinhI7q/fIPLJ76FtuGmt -zlxo/Jy1aaUgM/e485+7aoNSGi2/t6zGqGuotdUCO5epgrUHX+9fOJnnrYTG9ixp -FSHTT69y72khnw3eMP8NnOS3Lu+xLEzQHNbUDfB8uyVEX4pyA3FPVVqwIaeJDiPS -2x7Sl5KKwLbqPPKRFRC1qLsN4KcqeXBG+piTLPExdzsLbrU9JZMcaNmSmUabdg20 -SCwIuU2kHEpO7O7yNGeV9m0CGFUaoCAHVG70oXHxpVjAJbtgyoBkiwSxghCxXkfW -Mg+1B2k4Gk1WrLjIyasH6p0MLUJ7qLYN+c+wF7ms00F/w04rM6zUpkgnqsazpw6F -weUhpA8qY2vOJN6rsB4byaOUnd33xhAwcY/pIAcjW7UBjNmFMB1DQg== ------END RSA PRIVATE KEY----- diff --git a/tests/integration/long/ssl/python_driver.pem b/tests/integration/long/ssl/python_driver.pem deleted file mode 100644 index 83556fd9..00000000 --- a/tests/integration/long/ssl/python_driver.pem +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDrzCCApegAwIBAgIEFPORBzANBgkqhkiG9w0BAQsFADCBhjELMAkGA1UEBhMCVVMxEzARBgNV -BAgTCkNhbGlmb3JuaWExFDASBgNVBAcTC1NhbnRhIENsYXJhMRYwFAYDVQQKEw1EYXRhU3RheCBJ -bmMuMRwwGgYDVQQLExNQeXRob24gRHJpdmVyIFRlc3RzMRYwFAYDVQQDEw1QeXRob24gRHJpdmVy -MCAXDTE1MDYyNTE3MDAxOFoYDzIxMTUwNjAxMTcwMDE4WjCBhjELMAkGA1UEBhMCVVMxEzARBgNV -BAgTCkNhbGlmb3JuaWExFDASBgNVBAcTC1NhbnRhIENsYXJhMRYwFAYDVQQKEw1EYXRhU3RheCBJ -bmMuMRwwGgYDVQQLExNQeXRob24gRHJpdmVyIFRlc3RzMRYwFAYDVQQDEw1QeXRob24gRHJpdmVy -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAjUi6xfmieLqx9yD7HhkB6sjHfbS51xE7 -aaRySjTA1p9mPfPMPPMZ0NIsiDsUlT7Fa+LWZU9cGuJBFg8YxjU5Eij7smFd0J4tawk51KudDdUZ -crALGFC3WY7sEboMl8UHqV+kESPlNm5/JSNSYkNm1TMi9mHcB/Bg3RDpORRW/keMtBSLRxCVjsu6 -GvKN8wuEfU/bTmI9aUjbFRCFunBX6QEJeU44BYEJXNAls+X8szBfVmFHwefatSlh++uu7kY6zAQI -v74PHMZ8w+mWmbjpxEsmSg+uljGCjQHjKTNSFBY9kWWh2LBiTcZuEsQ9DK0J/+1tUa0s5vq6CjUK -XRxwpQIDAQABoyEwHzAdBgNVHQ4EFgQUJwTYG8dcZDt7faalYwCHmG3jp3swDQYJKoZIhvcNAQEL -BQADggEBABtg3SLFUkcbISoZO4/UdHY2z4BTJZXt5uep9qIVQu7NospzsafgyGF0YAQJq0fLhBlB -DVx6IxIvDZUfzKdIVMYJTQh7ZJ7kdsdhcRIhKZK4Lko3iOwkWS0aXsbQP+hcXrwGViYIV6+Rrmle -LuxwexVfJ+wXCJcc4vvbecVsOs2+ms1w98cUXvVS1d9KpHo37LK1mRsnYPik3+CBeYXqa8FzMJc1 -dlC/dNwrCXYJZ1QMEpyaP4TI3fmkg8OJ3glZkQr6nz1TUMwMmAvudb79IrmQKBuO6k99DZFJC6Er -oh6ff8G/F5YY+dWEqsF0KqNhL9uwyrqG3CTX5Eocg2AGkWI= ------END CERTIFICATE----- diff --git a/tests/integration/long/ssl/python_driver_bad.pem b/tests/integration/long/ssl/python_driver_bad.pem deleted file mode 100644 index 978d6c53..00000000 --- 
a/tests/integration/long/ssl/python_driver_bad.pem +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDrzCCApegAwIBAgIEFPORBzANBgkqhkiG9w0BAQsFADCBhjELMAkGA1UEBhMCVVMxEzARBgNV -BAgTCkNhbGlmb3JuaWExFDASBgNVBAcTC1NhbnRhIENsYXJhMRYwFAYDVQQKEw1EYXRhU3RheCBJ -bmMuMRwwGgYDVQQLExNQeXRob24gRHJpdmVyIFRlc3RzMRYwFAYDVQQDEw1QeXRob24gRHJpdmVy -MCAXDTE1MDYyNTE3MDAxOFoYDzIxMTUwNjAxMTcwMDE4WjCBhjELMAkGA1UEBhMCVVMxEzARBgNV -BAgTCkNhbGlmb3JuaWExFDASBgNVBAcTC1NhbnRhIENsYXJhMRYwFAYDVQQKEw1EYXRhU3RheCBJ -bmMuMRwwGgYDVQQLExNQeXRob24gRHJpdmVyIFRlc3RzMRYwFAYDVQQDEw1QeXRob24gRHJpdmVy -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAjUi6xfmieLqx9yD7HhkB6sjHfbS51xE7 -aaRySjTA1p9mPfPMPPMZ0NIsiDsUlT7Fa+LWZU9cGuJBFg8YxjU5Eij7smFd0J4tawk51KudDdUZ -crALGFC3WY7sEboMl8UHqV+kESPlNm5/JSNSYkNm1TMi9mHcB/Bg3RDpORRW/keMtBSLRxCVjsu6 -GvKN8wuEfU/bTmI9aUjbFRCFunBX6QEJeU44BYEJXNAls+X8szBfVmFHwefatSlh++uu7kY6zAQI -v74PHMZ8w+mWmbjpxEsmSg+uljGCjQHjKTNSFBY9kWWh2LBiTcZuEsQ9DK0J/+1tUa0s5vq6CjUK -XRxwpQIDAQABoyE666666gNVHQ4EFgQUJwTYG8dcZDt7faalYwCHmG3jp3swDQYJKoZIhvcNAQEL -BQADggEBABtg3SLFUkcbISoZO4/UdHY2z4BTJZXt5uep9qIVQu7NospzsafgyGF0YAQJq0fLhBlB -DVx6IxIvDZUfzKdIVMYJTQh7ZJ7kdsdhcRIhKZK4Lko3iOwkWS0aXsbQP+hcXrwGViYIV6+Rrmle -LuxwexVfJ+wXCJcc4vvbecVsOs2+ms1w98cUXvVS1d9KpHo37LK1mRsnYPik3+CBeYXqa8FzMJc1 -dlC/dNwrCXYJZ1QMEpyaP4TI3fmkg8OJ3glZkQr6nz1TUMwMmAvudb79IrmQKBuO6k99DZFJC6Er -oh6ff8G/F5YY+dWEqsF0KqNhL9uwyrqG3CTX5Eocg2AGkWI= ------END CERTIFICATE----- diff --git a/tests/integration/long/ssl/python_driver_no_pass.key b/tests/integration/long/ssl/python_driver_no_pass.key deleted file mode 100644 index 8dd14f84..00000000 --- a/tests/integration/long/ssl/python_driver_no_pass.key +++ /dev/null @@ -1,27 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIIEowIBAAKCAQEAjUi6xfmieLqx9yD7HhkB6sjHfbS51xE7aaRySjTA1p9mPfPM -PPMZ0NIsiDsUlT7Fa+LWZU9cGuJBFg8YxjU5Eij7smFd0J4tawk51KudDdUZcrAL -GFC3WY7sEboMl8UHqV+kESPlNm5/JSNSYkNm1TMi9mHcB/Bg3RDpORRW/keMtBSL -RxCVjsu6GvKN8wuEfU/bTmI9aUjbFRCFunBX6QEJeU44BYEJXNAls+X8szBfVmFH -wefatSlh++uu7kY6zAQIv74PHMZ8w+mWmbjpxEsmSg+uljGCjQHjKTNSFBY9kWWh -2LBiTcZuEsQ9DK0J/+1tUa0s5vq6CjUKXRxwpQIDAQABAoIBAC3bpYQM+wdk0c79 -DYU/aLfkY5wRxSBhn38yuUYMyWrgYjdJoslFvuNg1MODKbMnpLzX6+8GS0cOmUGn -tMrhC50xYEEOCX1lWiib3gGBkoCi4pevPGqwCFMxaL54PQ4mDc6UFJTbqdJ5Gxva -0yrB5ebdqkN+kASjqU0X6Bt21qXB6BvwAgpIXSX8r+NoH2Z9dumSYD+bOwhXo+/b -FQ1wyLL78tDdlJ8KibwnTv9RtLQbALUinMEHyP+4Gp/t/JnxlcAfvEwggYBxFR1K -5sN8dMFbMZVNqNREXZyWCMQqPbKLhIHPHlNo5pJP7cUh9iVH4QwYNIbOqUza/aUx -z7DIISECgYEAvpAAdDiBExMOELz4+ku5Uk6wmVOMnAK6El4ijOXjJsOB4FB6M0A6 -THXlzLws0YLcoZ3Pm91z20rqmkv1VG+En27uKC1Dgqqd4DOQzMuPoPxzq/q2ozFH -V5U1a0tTmyynr3CFzQUJKLJs1pKKIp6HMiB48JWQc5q6ZaaomEnOiYsCgYEAvczB -Bwwf7oaZGhson1HdcYs5kUm9VkL/25dELUt6uq5AB5jjvfOYd7HatngNRCabUCgE -gcaNfJSwpbOEZ00AxKVSxGmyIP1YAlkVcSdfAPwGO6C1+V4EPHqYUW0AVHOYo7oB -0MCyLT6nSUNiHWyI7qSEwCP03SqyAKA1pDRUVI8CgYBt+bEpYYqsNW0Cn+yYlqcH -Jz6n3h3h03kLLKSH6AwlzOLhT9CWT1TV15ydgWPkLb+ize6Ip087mYq3LWsSJaHG -WUC8kxLJECo4v8mrRzdG0yr2b6SDnebsVsITf89qWGUVzLyLS4Kzp/VECCIMRK0F -ctQZFFffP8ae74WRDddSbQKBgQC7vZ9qEyo6zNUAp8Ck51t+BtNozWIFw7xGP/hm -PXUm11nqqecMa7pzG3BWcaXdtbqHrS3YGMi3ZHTfUxUzAU4zNb0LH+ndC/xURj4Z -cXJeDO01aiDWi5LxJ+snEAT1hGqF+WX2UcVtT741j/urU0KXnBDb5jU92A++4rps -tH5+LQKBgGHtOWD+ffKNw7IrVLhP16GmYoZZ05zh10d1eUa0ifgczjdAsuEH5/Aq -zK7MsDyPcQBH/pOwAcifWGEdXmn9hL6w5dn96ABfa8Qh9nXWrCE2OFD81PDU9Osd -wnwbTKlYWPBwdF7UCseKC7gXkUD6Ls0ADWJvrCI7AfQJv6jj6nnE ------END RSA PRIVATE KEY----- diff --git a/tests/integration/long/ssl/server_cert.pem b/tests/integration/long/ssl/server_cert.pem deleted file mode 100644 index 7c96b96a..00000000 --- 
a/tests/integration/long/ssl/server_cert.pem +++ /dev/null @@ -1,13 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICbjCCAdegAwIBAgIEP/N06DANBgkqhkiG9w0BAQsFADBqMQswCQYDVQQGEwJURTELMAkGA1UE -CBMCQ0ExFDASBgNVBAcTC1NhbnRhIENsYXJhMREwDwYDVQQKEwhEYXRhU3RheDELMAkGA1UECxMC -VEUxGDAWBgNVBAMTD1BoaWxpcCBUaG9tcHNvbjAeFw0xNDEwMDMxNTQ2NDdaFw0xNTAxMDExNTQ2 -NDdaMGoxCzAJBgNVBAYTAlRFMQswCQYDVQQIEwJDQTEUMBIGA1UEBxMLU2FudGEgQ2xhcmExETAP -BgNVBAoTCERhdGFTdGF4MQswCQYDVQQLEwJURTEYMBYGA1UEAxMPUGhpbGlwIFRob21wc29uMIGf -MA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQChGDwrhpQR0d+NoqilMgsBlR6A2Dd1oMyIUe42sU4t -N63g5N44Ic4RpTiyaWgnAkP332ok3YAuVbxytwEv2K9HrUSiokAiuinlhhHA8CXTHt/1ItzzWj9u -J3Hneb+5lOkXVTZX7Y+q3aSdpx/HnZqn4i27DtLZF0z3LccWPWRinQIDAQABoyEwHzAdBgNVHQ4E -FgQU9WJpUhgGTBBH4xZBCV7Y9YISCp4wDQYJKoZIhvcNAQELBQADgYEAF6e8eVAjoZhfyJ+jW5mB -0pXa2vr5b7VFQ45voNncGrB3aNbz/AWT7LCJw88+Y5SJITgwN/8o3ZY6Y3MyiqeQYGo9WxDSWb5A -dZWFa03Z+hrVDQuw1r118zIhdS4KYDCQM2JfWY32TwK0MNG/6BO876HfkDpcjCYzq8Gh0gEguOA= ------END CERTIFICATE----- diff --git a/tests/integration/long/ssl/server_keystore.jks b/tests/integration/long/ssl/server_keystore.jks deleted file mode 100644 index 81259355..00000000 Binary files a/tests/integration/long/ssl/server_keystore.jks and /dev/null differ diff --git a/tests/integration/long/ssl/server_trust.jks b/tests/integration/long/ssl/server_trust.jks deleted file mode 100644 index feb0784a..00000000 Binary files a/tests/integration/long/ssl/server_trust.jks and /dev/null differ diff --git a/tests/integration/long/test_ssl.py b/tests/integration/long/test_ssl.py index fc6a7066..9ae84f38 100644 --- a/tests/integration/long/test_ssl.py +++ b/tests/integration/long/test_ssl.py @@ -25,16 +25,16 @@ from tests.integration import use_singledc, PROTOCOL_VERSION, get_cluster, remov log = logging.getLogger(__name__) -DEFAULT_PASSWORD = "cassandra" +DEFAULT_PASSWORD = "pythondriver" # Server keystore trust store locations -SERVER_KEYSTORE_PATH = "tests/integration/long/ssl/server_keystore.jks" -SERVER_TRUSTSTORE_PATH = "tests/integration/long/ssl/server_trust.jks" +SERVER_KEYSTORE_PATH = "tests/integration/long/ssl/.keystore" +SERVER_TRUSTSTORE_PATH = "tests/integration/long/ssl/.truststore" # Client specific keys/certs -CLIENT_CA_CERTS = 'tests/integration/long/ssl/driver_ca_cert.pem' -DRIVER_KEYFILE = "tests/integration/long/ssl/python_driver_no_pass.key" -DRIVER_CERTFILE = "tests/integration/long/ssl/python_driver.pem" +CLIENT_CA_CERTS = 'tests/integration/long/ssl/cassandra.pem' +DRIVER_KEYFILE = "tests/integration/long/ssl/driver.key" +DRIVER_CERTFILE = "tests/integration/long/ssl/driver.pem" DRIVER_CERTFILE_BAD = "tests/integration/long/ssl/python_driver_bad.pem" @@ -78,6 +78,38 @@ def teardown_module(): remove_cluster() +def validate_ssl_options(ssl_options): + # find absolute path to client CA_CERTS + tries = 0 + while True: + if tries > 5: + raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") + try: + cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options=ssl_options) + session = cluster.connect() + break + except Exception: + ex_type, ex, tb = sys.exc_info() + log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + + # attempt a few simple commands. 
+ insert_keyspace = """CREATE KEYSPACE ssltest + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} + """ + statement = SimpleStatement(insert_keyspace) + statement.consistency_level = 3 + session.execute(statement) + + drop_keyspace = "DROP KEYSPACE ssltest" + statement = SimpleStatement(drop_keyspace) + statement.consistency_level = ConsistencyLevel.ANY + session.execute(statement) + + cluster.shutdown() + + class SSLConnectionTests(unittest.TestCase): @classmethod @@ -102,36 +134,32 @@ class SSLConnectionTests(unittest.TestCase): # find absolute path to client CA_CERTS abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + ssl_options = {'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl.PROTOCOL_TLSv1} + validate_ssl_options(ssl_options=ssl_options) - tries = 0 - while True: - if tries > 5: - raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") - try: - cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options={'ca_certs': abs_path_ca_cert_path, - 'ssl_version': ssl.PROTOCOL_TLSv1}) - session = cluster.connect() - break - except Exception: - ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) - del tb - tries += 1 + def test_can_connect_with_ssl_ca_host_match(self): + """ + Test to validate that we are able to connect to a cluster using ssl, and host matching - # attempt a few simple commands. - insert_keyspace = """CREATE KEYSPACE ssltest - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} - """ - statement = SimpleStatement(insert_keyspace) - statement.consistency_level = 3 - session.execute(statement) + test_can_connect_with_ssl_ca_host_match performs a simple sanity check to ensure that we can connect to a cluster with ssl + authentication via simple server-side shared certificate authority. 
It also validates that the host ip matches what is expected - drop_keyspace = "DROP KEYSPACE ssltest" - statement = SimpleStatement(drop_keyspace) - statement.consistency_level = ConsistencyLevel.ANY - session.execute(statement) + @since 3.3 + @jira_ticket PYTHON-296 + @expected_result The client can connect via SSL and preform some basic operations, with check_hostname specified - cluster.shutdown() + @test_category connection:ssl + """ + + # find absolute path to client CA_CERTS + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + ssl_options = {'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl.PROTOCOL_TLSv1, + 'cert_reqs': ssl.CERT_REQUIRED, + 'check_hostname': True} + + validate_ssl_options(ssl_options=ssl_options) class SSLConnectionAuthTests(unittest.TestCase): @@ -158,40 +186,39 @@ class SSLConnectionAuthTests(unittest.TestCase): abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) abs_driver_keyfile = os.path.abspath(DRIVER_KEYFILE) abs_driver_certfile = os.path.abspath(DRIVER_CERTFILE) + ssl_options = {'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl.PROTOCOL_TLSv1, + 'keyfile': abs_driver_keyfile, + 'certfile': abs_driver_certfile} + validate_ssl_options(ssl_options) - tries = 0 - while True: - if tries > 5: - raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") - try: - cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options={'ca_certs': abs_path_ca_cert_path, - 'ssl_version': ssl.PROTOCOL_TLSv1, - 'keyfile': abs_driver_keyfile, - 'certfile': abs_driver_certfile}) + def test_can_connect_with_ssl_client_auth_host_name(self): + """ + Test to validate that we can connect to a C* cluster that has client_auth enabled, and hostmatching - session = cluster.connect() - break - except Exception: - ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) - del tb - tries += 1 + This test will setup and use a c* cluster that has client authentication enabled. It will then attempt + to connect using valid client keys, and certs (that are in the server's truststore), and attempt to preform some + basic operations, with check_hostname specified + @jira_ticket PYTHON-296 + @since 3.3 - # attempt a few simple commands. 
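For reference, a hedged sketch of the application-side pattern these tests exercise for PYTHON-296 (SSL with hostname verification); the certificate path is a placeholder:

import ssl
from cassandra.cluster import Cluster

ssl_options = {'ca_certs': '/path/to/cassandra.pem',   # placeholder path
               'ssl_version': ssl.PROTOCOL_TLSv1,
               'cert_reqs': ssl.CERT_REQUIRED,
               'check_hostname': True}
cluster = Cluster(ssl_options=ssl_options)
session = cluster.connect()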
+ @expected_result The client can connect via SSL and preform some basic operations - insert_keyspace = """CREATE KEYSPACE ssltest - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} - """ - statement = SimpleStatement(insert_keyspace) - statement.consistency_level = 3 - session.execute(statement) + @test_category connection:ssl + """ - drop_keyspace = "DROP KEYSPACE ssltest" - statement = SimpleStatement(drop_keyspace) - statement.consistency_level = ConsistencyLevel.ANY - session.execute(statement) + # Need to get absolute paths for certs/key + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + abs_driver_keyfile = os.path.abspath(DRIVER_KEYFILE) + abs_driver_certfile = os.path.abspath(DRIVER_CERTFILE) - cluster.shutdown() + ssl_options = {'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl.PROTOCOL_TLSv1, + 'keyfile': abs_driver_keyfile, + 'certfile': abs_driver_certfile, + 'cert_reqs': ssl.CERT_REQUIRED, + 'check_hostname': True} + validate_ssl_options(ssl_options) def test_cannot_connect_without_client_auth(self): """ @@ -213,6 +240,7 @@ class SSLConnectionAuthTests(unittest.TestCase): with self.assertRaises(NoHostAvailable) as context: cluster.connect() + cluster.shutdown() def test_cannot_connect_with_bad_client_auth(self): """ @@ -239,3 +267,4 @@ class SSLConnectionAuthTests(unittest.TestCase): 'certfile': abs_driver_certfile}) with self.assertRaises(NoHostAvailable) as context: cluster.connect() + cluster.shutdown() diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 003c78a9..d6650a47 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -27,11 +27,11 @@ from cassandra.cluster import Cluster, NoHostAvailable from cassandra.concurrent import execute_concurrent from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance, - WhiteListRoundRobinPolicy) + WhiteListRoundRobinPolicy, AddressTranslator) from cassandra.protocol import MAX_SUPPORTED_VERSION from cassandra.query import SimpleStatement, TraceUnavailable -from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, get_node, CASSANDRA_VERSION, execute_until_pass, execute_with_long_wait_retry +from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, get_node, CASSANDRA_VERSION, execute_until_pass, execute_with_long_wait_retry, BasicExistingKeyspaceUnitTestCase, get_node from tests.integration.util import assert_quiescent_pool_state @@ -41,6 +41,39 @@ def setup_module(): class ClusterTests(unittest.TestCase): + + def test_host_resolution(self): + """ + Test to insure A records are resolved appropriately. 
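As a hedged illustration of PYTHON-415/PYTHON-103: hostnames given as contact points are now resolved up front, so a DNS name with several A records contributes one contact point per address and duplicates collapse.

from cassandra.cluster import Cluster

# With a DNS name that has several A records, each resolved address would
# appear here as its own contact point.
cluster = Cluster(contact_points=['localhost'])
print(cluster.contact_points_resolved)   # e.g. ['127.0.0.1']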
+ + @since 3.3 + @jira_ticket PYTHON-415 + @expected_result hostname will be transformed into IP + + @test_category connection + """ + cluster = Cluster(contact_points=["localhost"], protocol_version=PROTOCOL_VERSION, connect_timeout=1) + self.assertTrue('127.0.0.1' in cluster.contact_points_resolved) + + def test_host_duplication(self): + """ + Ensure that duplicate hosts in the contact points are surfaced in the cluster metadata + + @since 3.3 + @jira_ticket PYTHON-103 + @expected_result duplicate hosts aren't surfaced in cluster.metadata + + @test_category connection + """ + cluster = Cluster(contact_points=["localhost", "127.0.0.1", "localhost", "localhost", "localhost"], protocol_version=PROTOCOL_VERSION, connect_timeout=1) + cluster.connect() + self.assertEqual(len(cluster.metadata.all_hosts()), 3) + cluster.shutdown() + cluster = Cluster(contact_points=["127.0.0.1", "localhost"], protocol_version=PROTOCOL_VERSION, connect_timeout=1) + cluster.connect() + self.assertEqual(len(cluster.metadata.all_hosts()), 3) + cluster.shutdown() + def test_raise_error_on_control_connection_timeout(self): """ Test for initial control connection timeout @@ -585,3 +618,54 @@ class ClusterTests(unittest.TestCase): cluster.shutdown() +class LocalHostAdressTranslator(AddressTranslator): + + def __init__(self, addr_map=None): + self.addr_map = addr_map + + def translate(self, addr): + new_addr = self.addr_map.get(addr) + return new_addr + + +class TestAddressTranslation(unittest.TestCase): + + def test_address_translator_basic(self): + """ + Test host address translation + + Uses a custom Address Translator to map all ip back to one. + Validates AddressTranslator invocation by ensuring that only meta data associated with single + host is populated + + @since 3.3 + @jira_ticket PYTHON-69 + @expected_result only one hosts' metadata will be populeated + + @test_category metadata + """ + lh_ad = LocalHostAdressTranslator({'127.0.0.1': '127.0.0.1', '127.0.0.2': '127.0.0.1', '127.0.0.3': '127.0.0.1'}) + c = Cluster(address_translator=lh_ad) + c.connect() + self.assertEqual(len(c.metadata.all_hosts()), 1) + c.shutdown() + + def test_address_translator_with_mixed_nodes(self): + """ + Test host address translation + + Uses a custom Address Translator to map ip's of non control_connection nodes to each other + Validates AddressTranslator invocation by ensuring that metadata for mapped hosts is also mapped + + @since 3.3 + @jira_ticket PYTHON-69 + @expected_result metadata for crossed hosts will also be crossed + + @test_category metadata + """ + adder_map = {'127.0.0.1': '127.0.0.1', '127.0.0.2': '127.0.0.3', '127.0.0.3': '127.0.0.2'} + lh_ad = LocalHostAdressTranslator(adder_map) + c = Cluster(address_translator=lh_ad) + c.connect() + for host in c.metadata.all_hosts(): + self.assertEqual(adder_map.get(str(host)), host.broadcast_address) diff --git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 51dd11a7..2d07b920 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -24,12 +24,15 @@ from threading import Thread, Event import time from cassandra import ConsistencyLevel, OperationTimedOut -from cassandra.cluster import NoHostAvailable +from cassandra.cluster import NoHostAvailable, Cluster from cassandra.io.asyncorereactor import AsyncoreConnection from cassandra.protocol import QueryMessage +from cassandra.connection import Connection +from cassandra.policies import WhiteListRoundRobinPolicy, 
HostStateListener +from cassandra.pool import HostConnectionPool from tests import is_monkey_patched -from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration import use_singledc, PROTOCOL_VERSION, get_node try: from cassandra.io.libevreactor import LibevConnection @@ -41,6 +44,131 @@ def setup_module(): use_singledc() +class ConnectionTimeoutTest(unittest.TestCase): + + def setUp(self): + self.defaultInFlight = Connection.max_in_flight + Connection.max_in_flight = 2 + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) + self.session = self.cluster.connect() + + def tearDown(self): + Connection.max_in_flight = self.defaultInFlight + self.cluster.shutdown() + + def test_in_flight_timeout(self): + """ + Test to ensure that connection id fetching will block when max_id is reached/ + + In previous versions of the driver this test will cause a + NoHostAvailable exception to be thrown, when the max_id is restricted + + @since 3.3 + @jira_ticket PYTHON-514 + @expected_result When many requests are run on a single node connection acquisition should block + until connection is available or the request times out. + + @test_category connection timeout + """ + futures = [] + query = '''SELECT * FROM system.local''' + for i in range(100): + futures.append(self.session.execute_async(query)) + + for future in futures: + future.result() + + +class TestHostListener(HostStateListener): + host_down = None + + def on_down(self, host): + host_down = host + + +class HeartbeatTest(unittest.TestCase): + """ + Test to validate failing a heartbeat check doesn't mark a host as down + + @since 3.3 + @jira_ticket PYTHON-286 + @expected_result host should not be marked down when heartbeat fails + + @test_category connection heartbeat + """ + + def setUp(self): + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=1) + self.session = self.cluster.connect() + + def tearDown(self): + self.cluster.shutdown() + + def test_heart_beat_timeout(self): + # Setup a host listener to ensure the nodes don't go down + test_listener = TestHostListener() + host = "127.0.0.1" + node = get_node(1) + initial_connections = self.fetch_connections(host, self.cluster) + self.assertNotEqual(len(initial_connections), 0) + self.cluster.register_listener(test_listener) + # Pause the node + node.pause() + # Wait for connections associated with this host go away + self.wait_for_no_connections(host, self.cluster) + # Resume paused node + node.resume() + # Run a query to ensure connections are re-established + current_host = "" + count = 0 + while current_host != host and count < 100: + rs = self.session.execute_async("SELECT * FROM system.local", trace=False) + rs.result() + current_host = str(rs._current_host) + count += 1 + time.sleep(.1) + self.assertLess(count, 100, "Never connected to the first node") + new_connections = self.wait_for_connections(host, self.cluster) + self.assertIsNone(test_listener.host_down) + # Make sure underlying new connections don't match previous ones + for connection in initial_connections: + self.assertFalse(connection in new_connections) + + def fetch_connections(self, host, cluster): + # Given a cluster object and host grab all connection associated with that host + connections = [] + holders = cluster.get_connection_holders() + for conn in holders: + if host == str(getattr(conn, 'host', '')): + if isinstance(conn, HostConnectionPool): + if conn._connections is not None: + 
connections.append(conn._connections) + else: + if conn._connection is not None: + connections.append(conn._connection) + return connections + + def wait_for_connections(self, host, cluster): + retry = 0 + while(retry < 300): + retry += 1 + connections = self.fetch_connections(host, cluster) + if len(connections) is not 0: + return connections + time.sleep(.1) + self.fail("No new connections found") + + def wait_for_no_connections(self, host, cluster): + retry = 0 + while(retry < 100): + retry += 1 + connections = self.fetch_connections(host, cluster) + if len(connections) is 0: + return + time.sleep(.5) + self.fail("Connections never cleared") + + class ConnectionTests(object): klass = None diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index dc24d0a3..783df261 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -7,11 +7,12 @@ try: except ImportError: import unittest +from cassandra import DriverException, Timeout, AlreadyExists from cassandra.query import tuple_factory -from cassandra.cluster import Cluster -from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler - -from tests.integration import use_singledc, PROTOCOL_VERSION, notprotocolv1, drop_keyspace_shutdown_cluster +from cassandra.cluster import Cluster, NoHostAvailable +from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler, ConfigurationException +from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY +from tests.integration import use_singledc, PROTOCOL_VERSION, notprotocolv1, drop_keyspace_shutdown_cluster, VERIFY_CYTHON, BasicSharedKeyspaceUnitTestCase, execute_with_retry_tolerant from tests.integration.datatype_utils import update_datatypes from tests.integration.standard.utils import ( create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes) @@ -123,6 +124,20 @@ class CythonProtocolHandlerTest(unittest.TestCase): cluster.shutdown() + @numpytest + def test_cython_numpy_are_installed_valid(self): + """ + Test to validate that cython and numpy are installed correctly + @since 3.3.0 + @jira_ticket PYTHON-543 + @expected_result Cython and Numpy should be present + + @test_category configuration + """ + if VERIFY_CYTHON: + self.assertTrue(HAVE_CYTHON) + self.assertTrue(HAVE_NUMPY) + def _verify_numpy_page(self, page): colnames = self.colnames datatypes = get_primitive_datatypes() @@ -188,3 +203,69 @@ def verify_iterator_data(assertEqual, results): for expected, actual in zip(params, result): assertEqual(actual, expected) return count + + +class NumpyNullTest(BasicSharedKeyspaceUnitTestCase): + + # A dictionary containing table key to type. 
+ # Boolean dictates whether or not the type can be deserialized with null value + NUMPY_TYPES = {"v1": ('bigint', False), + "v2": ('double', False), + "v3": ('float', False), + "v4": ('int', False), + "v5": ('smallint', False), + "v6": ("ascii", True), + "v7": ("blob", True), + "v8": ("boolean", True), + "v9": ("decimal", True), + "v10": ("inet", True), + "v11": ("text", True), + "v12": ("timestamp", True), + "v13": ("timeuuid", True), + "v14": ("uuid", True), + "v15": ("varchar", True), + "v16": ("varint", True), + } + + def setUp(self): + self.session.client_protocol_handler = NumpyProtocolHandler + self.session.row_factory = tuple_factory + + @numpytest + def test_null_types3(self): + """ + Test to validate that the numpy protocol handler can deal with null values. + @since 3.3.0 + @jira_ticket PYTHON-550 + @expected_result Numpy can handle non mapped types' null values. + + @test_category data_types:serialization + """ + + self.create_table_of_types() + self.session.execute("INSERT INTO {0}.{1} (k) VALUES (1)".format(self.keyspace_name, self.function_table_name)) + self.validate_types() + + def create_table_of_types(self): + """ + Builds a table containing all the numpy types + """ + base_ddl = '''CREATE TABLE {0}.{1} (k int PRIMARY KEY'''.format(self.keyspace_name, self.function_table_name, type) + for key, value in NumpyNullTest.NUMPY_TYPES.items(): + base_ddl = base_ddl+", {0} {1}".format(key, value[0]) + base_ddl = base_ddl+")" + execute_with_retry_tolerant(self.session, base_ddl, (DriverException, NoHostAvailable, Timeout), (ConfigurationException, AlreadyExists)) + + def validate_types(self): + """ + Selects each type from the table and expects either an exception or None depending on type + """ + for key, value in NumpyNullTest.NUMPY_TYPES.items(): + select = "SELECT {0} from {1}.{2}".format(key,self.keyspace_name, self.function_table_name) + if value[1]: + rs = execute_with_retry_tolerant(self.session, select, (NoHostAvailable), ()) + self.assertEqual(rs[0].get('v1'), None) + else: + with self.assertRaises(ValueError): + execute_with_retry_tolerant(self.session, select, (NoHostAvailable), ()) + diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index 583943bc..4d86723f 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -34,18 +34,84 @@ from cassandra.policies import SimpleConvictionPolicy from cassandra.pool import Host from tests.integration import get_cluster, use_singledc, PROTOCOL_VERSION, get_server_versions, execute_until_pass, \ - BasicSegregatedKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster + BasicSegregatedKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase, BasicExistingKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster, CASSANDRA_VERSION from tests.unit.cython.utils import notcython + def setup_module(): use_singledc() global CASS_SERVER_VERSION CASS_SERVER_VERSION = get_server_versions()[0] +class HostMetatDataTests(BasicExistingKeyspaceUnitTestCase): + def test_broadcast_listen_address(self): + """ + Check to ensure that the broadcast and listen adresss is populated correctly + + @since 3.3 + @jira_ticket PYTHON-332 + @expected_result They are populated for C*> 2.0.16, 2.1.6, 2.2.0 + + @test_category metadata + """ + # All nodes should have the broadcast_address set + for host in self.cluster.metadata.all_hosts(): + self.assertIsNotNone(host.broadcast_address) + con = 
self.cluster.control_connection.get_connections()[0] + local_host = con.host + # The control connection node should have the listen address set. + listen_addrs = [host.listen_address for host in self.cluster.metadata.all_hosts()] + self.assertTrue(local_host in listen_addrs) + + def test_host_release_version(self): + """ + Checks the hosts release version and validates that it is equal to the + Cassandra version we are using in our test harness. + + @since 3.3 + @jira_ticket PYTHON-301 + @expected_result host.release version should match our specified Cassandra version. + + @test_category metadata + """ + for host in self.cluster.metadata.all_hosts(): + self.assertTrue(host.release_version.startswith(CASSANDRA_VERSION)) + + class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase): + def test_schema_metadata_disable(self): + """ + Checks to ensure that schema metadata_enabled, and token_metadata_enabled + flags work correctly. + + @since 3.3 + @jira_ticket PYTHON-327 + @expected_result schema metadata will not be populated when schema_metadata_enabled is fause + token_metadata will be missing when token_metadata is set to false + + @test_category metadata + """ + # Validate metadata is missing where appropriate + no_schema = Cluster(schema_metadata_enabled=False) + no_schema_session = no_schema.connect() + self.assertEqual(len(no_schema.metadata.keyspaces), 0) + self.assertEqual(no_schema.metadata.export_schema_as_string(), '') + no_token = Cluster(token_metadata_enabled=False) + no_token_session = no_token.connect() + self.assertEqual(len(no_token.metadata.token_map.token_to_host_owner), 0) + + # Do a simple query to ensure queries are working + query = "SELECT * FROM system.local" + no_schema_rs = no_schema_session.execute(query) + no_token_rs = no_token_session.execute(query) + self.assertIsNotNone(no_schema_rs[0]) + self.assertIsNotNone(no_token_rs[0]) + no_schema.shutdown() + no_token.shutdown() + def make_create_statement(self, partition_cols, clustering_cols=None, other_cols=None, compact=False): clustering_cols = clustering_cols or [] other_cols = other_cols or [] @@ -120,7 +186,8 @@ class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase): self.assertEqual([], tablemeta.clustering_key) self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) - parser = get_schema_parser(self.cluster.control_connection._connection, 1) + cc = self.cluster.control_connection._connection + parser = get_schema_parser(cc, str(CASS_SERVER_VERSION[0]), 1) for option in tablemeta.options: self.assertIn(option, parser.recognized_table_options) @@ -1067,14 +1134,14 @@ Approximate structure, for reference: CREATE TABLE legacy.composite_comp_with_col ( key blob, - t timeuuid, b blob, s text, + t timeuuid, "b@6869746d65776974686d75736963" blob, "b@6d616d6d616a616d6d61" blob, - PRIMARY KEY (key, t, b, s) + PRIMARY KEY (key, b, s, t) ) WITH COMPACT STORAGE - AND CLUSTERING ORDER BY (t ASC, b ASC, s ASC) + AND CLUSTERING ORDER BY (b ASC, s ASC, t ASC) AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}' AND comment = 'Stores file meta data' AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'} @@ -1195,8 +1262,8 @@ Approximate structure, for reference: CREATE TABLE legacy.composite_comp_no_col ( key blob, - column1 'org.apache.cassandra.db.marshal.DynamicCompositeType(org.apache.cassandra.db.marshal.TimeUUIDType, org.apache.cassandra.db.marshal.BytesType, org.apache.cassandra.db.marshal.UTF8Type)', 
- column2 text, + column1 'org.apache.cassandra.db.marshal.DynamicCompositeType(org.apache.cassandra.db.marshal.BytesType, org.apache.cassandra.db.marshal.UTF8Type, org.apache.cassandra.db.marshal.TimeUUIDType)', + column2 timeuuid, value blob, PRIMARY KEY (key, column1, column1, column2) ) WITH COMPACT STORAGE @@ -1897,7 +1964,7 @@ class BadMetaTest(unittest.TestCase): cls.session.execute("CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) cls.session.set_keyspace(cls.keyspace_name) connection = cls.cluster.control_connection._connection - cls.parser_class = get_schema_parser(connection, timeout=20).__class__ + cls.parser_class = get_schema_parser(connection, str(CASS_SERVER_VERSION[0]), timeout=20).__class__ @classmethod def teardown_class(cls): diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 48627b3f..13758b65 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -145,13 +145,13 @@ class MetricsTests(unittest.TestCase): query = SimpleStatement("INSERT INTO test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) with self.assertRaises(Unavailable): self.session.execute(query) - self.assertEqual(1, self.cluster.metrics.stats.unavailables) + self.assertEqual(2, self.cluster.metrics.stats.unavailables) # Test write query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) with self.assertRaises(Unavailable): self.session.execute(query, timeout=None) - self.assertEqual(2, self.cluster.metrics.stats.unavailables) + self.assertEqual(4, self.cluster.metrics.stats.unavailables) finally: get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) # Give some time for the cluster to come back up, for the next test diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 51ab1f51..dff9715a 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -20,13 +20,13 @@ try: except ImportError: import unittest # noqa -from cassandra import ConsistencyLevel +from cassandra import ConsistencyLevel, Unavailable, InvalidRequest from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement, BatchStatement, BatchType, dict_factory, TraceUnavailable) -from cassandra.cluster import Cluster +from cassandra.cluster import Cluster, NoHostAvailable from cassandra.policies import HostDistance -from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions, greaterthanprotocolv3 +from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions, greaterthanprotocolv3, get_node import time import re @@ -109,7 +109,7 @@ class QueryTests(BasicSharedKeyspaceUnitTestCase): only be the case if the c* version is 2.2 or greater @since 2.6.0 - @jira_ticket PYTHON-235 + @jira_ticket PYTHON-435 @expected_result client address should be present in C* >= 2.2, otherwise should be none. @test_category tracing @@ -138,6 +138,32 @@ class QueryTests(BasicSharedKeyspaceUnitTestCase): self.assertIsNotNone(client_ip, "Client IP was not set in trace with C* >= 2.2") self.assertTrue(pat.match(client_ip), "Client IP from trace did not match the expected value") + def test_trace_cl(self): + """ + Test to ensure that CL is set correctly honored when executing trace queries. 
+ + @since 3.3 + @jira_ticket PYTHON-435 + @expected_result Consistency Levels set on get_query_trace should be honored + """ + # Execute a query + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + response_future = self.session.execute_async(statement, trace=True) + response_future.result() + with self.assertRaises(Unavailable): + response_future.get_query_trace(query_cl=ConsistencyLevel.THREE) + # Try again with a smattering of other CL's + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.TWO).trace_id) + response_future = self.session.execute_async(statement, trace=True) + response_future.result() + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ONE).trace_id) + response_future = self.session.execute_async(statement, trace=True) + response_future.result() + with self.assertRaises(InvalidRequest): + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ANY).trace_id) + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.QUORUM).trace_id) + def test_incomplete_query_trace(self): """ Tests to ensure that partial tracing works. diff --git a/tests/integration/standard/test_row_factories.py b/tests/integration/standard/test_row_factories.py index a1864648..df9765ee 100644 --- a/tests/integration/standard/test_row_factories.py +++ b/tests/integration/standard/test_row_factories.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCaseWFunctionTable +from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCaseWFunctionTable, BasicSharedKeyspaceUnitTestCase, execute_until_pass try: import unittest2 as unittest @@ -28,6 +28,52 @@ def setup_module(): use_singledc() +class NameTupleFactory(BasicSharedKeyspaceUnitTestCase): + + def setUp(self): + super(NameTupleFactory, self).setUp() + self.session.row_factory = named_tuple_factory + ddl = ''' + CREATE TABLE {0}.{1} ( + k int PRIMARY KEY, + v1 text, + v2 text, + v3 text)'''.format(self.ks_name, self.function_table_name) + self.session.execute(ddl) + execute_until_pass(self.session, ddl) + + def test_sanitizing(self): + """ + Test to ensure that same named results are surfaced in the NamedTupleFactory + + Creates a table with a few different text fields. Inserts a few values in that table. + It then fetches the values and confirms that despite all be being selected as the same name + they are propagated in the result set differently. + + @since 3.3 + @jira_ticket PYTHON-467 + @expected_result duplicate named results have unique row names. 
+ + @test_category queries + """ + + for x in range(5): + insert1 = ''' + INSERT INTO {0}.{1} + ( k , v1, v2, v3 ) + VALUES + ( 1 , 'v1{2}', 'v2{2}','v3{2}' ) + '''.format(self.keyspace_name, self.function_table_name, str(x)) + self.session.execute(insert1) + + query = "SELECT v1 AS duplicate, v2 AS duplicate, v3 AS duplicate from {0}.{1}".format(self.ks_name, self.function_table_name) + rs = self.session.execute(query) + row = rs[0] + self.assertTrue(hasattr(row, 'duplicate')) + self.assertTrue(hasattr(row, 'duplicate_')) + self.assertTrue(hasattr(row, 'duplicate__')) + + class RowFactoryTests(BasicSharedKeyspaceUnitTestCaseWFunctionTable): """ Test different row_factories and access code diff --git a/tests/integration/standard/utils.py b/tests/integration/standard/utils.py index 8749c5e3..4011047f 100644 --- a/tests/integration/standard/utils.py +++ b/tests/integration/standard/utils.py @@ -45,7 +45,11 @@ def get_all_primitive_params(key): """ params = [key] for datatype in PRIMITIVE_DATATYPES: - params.append(get_sample(datatype)) + # Also test for empty strings + if key == 1 and datatype == 'ascii': + params.append('') + else: + params.append(get_sample(datatype)) return params diff --git a/tests/unit/cython/utils.py b/tests/unit/cython/utils.py index 6dbeb15c..916de137 100644 --- a/tests/unit/cython/utils.py +++ b/tests/unit/cython/utils.py @@ -13,6 +13,7 @@ # limitations under the License. from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY +from tests.integration import VERIFY_CYTHON try: import unittest2 as unittest @@ -34,6 +35,6 @@ def cyimport(import_path): # @cythontest # def test_something(self): ... -cythontest = unittest.skipUnless(HAVE_CYTHON, 'Cython is not available') +cythontest = unittest.skipUnless((HAVE_CYTHON or VERIFY_CYTHON) or VERIFY_CYTHON, 'Cython is not available') notcython = unittest.skipIf(HAVE_CYTHON, 'Cython not supported') -numpytest = unittest.skipUnless(HAVE_CYTHON and HAVE_NUMPY, 'NumPy is not available') +numpytest = unittest.skipUnless((HAVE_CYTHON and HAVE_NUMPY) or VERIFY_CYTHON, 'NumPy is not available') diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 763875c9..3c4f477f 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -19,11 +19,75 @@ except ImportError: from mock import patch, Mock -from cassandra import ConsistencyLevel -from cassandra.cluster import _Scheduler, Session +from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ + InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException +from cassandra.cluster import _Scheduler, Session, Cluster from cassandra.query import SimpleStatement +class ExceptionTypeTest(unittest.TestCase): + + def test_exception_types(self): + """ + PYTHON-443 + Sanity check to ensure we don't unintentionally change class hierarchy of exception types + """ + self.assertTrue(issubclass(Unavailable, DriverException)) + self.assertTrue(issubclass(Unavailable, RequestExecutionException)) + + self.assertTrue(issubclass(ReadTimeout, DriverException)) + self.assertTrue(issubclass(ReadTimeout, RequestExecutionException)) + self.assertTrue(issubclass(ReadTimeout, Timeout)) + + self.assertTrue(issubclass(WriteTimeout, DriverException)) + self.assertTrue(issubclass(WriteTimeout, RequestExecutionException)) + 
self.assertTrue(issubclass(WriteTimeout, Timeout)) + + self.assertTrue(issubclass(CoordinationFailure, DriverException)) + self.assertTrue(issubclass(CoordinationFailure, RequestExecutionException)) + + self.assertTrue(issubclass(ReadFailure, DriverException)) + self.assertTrue(issubclass(ReadFailure, RequestExecutionException)) + self.assertTrue(issubclass(ReadFailure, CoordinationFailure)) + + self.assertTrue(issubclass(WriteFailure, DriverException)) + self.assertTrue(issubclass(WriteFailure, RequestExecutionException)) + self.assertTrue(issubclass(WriteFailure, CoordinationFailure)) + + self.assertTrue(issubclass(FunctionFailure, DriverException)) + self.assertTrue(issubclass(FunctionFailure, RequestExecutionException)) + + self.assertTrue(issubclass(RequestValidationException, DriverException)) + + self.assertTrue(issubclass(ConfigurationException, DriverException)) + self.assertTrue(issubclass(ConfigurationException, RequestValidationException)) + + self.assertTrue(issubclass(AlreadyExists, DriverException)) + self.assertTrue(issubclass(AlreadyExists, RequestValidationException)) + self.assertTrue(issubclass(AlreadyExists, ConfigurationException)) + + self.assertTrue(issubclass(InvalidRequest, DriverException)) + self.assertTrue(issubclass(InvalidRequest, RequestValidationException)) + + self.assertTrue(issubclass(Unauthorized, DriverException)) + self.assertTrue(issubclass(Unauthorized, RequestValidationException)) + + self.assertTrue(issubclass(AuthenticationFailed, DriverException)) + + self.assertTrue(issubclass(OperationTimedOut, DriverException)) + + self.assertTrue(issubclass(UnsupportedOperation, DriverException)) + + +class ContactListTest(unittest.TestCase): + + def test_invalid_types(self, *args): + with self.assertRaises(ValueError): + Cluster(contact_points=[None], protocol_version=4, connect_timeout=1) + with self.assertRaises(TypeError): + Cluster(contact_points="not a sequence", protocol_version=4, connect_timeout=1) + + class SchedulerTest(unittest.TestCase): # TODO: this suite could be expanded; for now just adding a test covering a ticket @@ -35,9 +99,6 @@ class SchedulerTest(unittest.TestCase): PYTHON-473 """ - sched = _Scheduler(None) - sched.schedule(0, lambda: None) - sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()"t class SessionTest(unittest.TestCase): diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 15fa6e72..2ac10a59 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -22,9 +22,11 @@ from six import BytesIO import time from threading import Lock -from cassandra.cluster import Cluster, Session +from cassandra import OperationTimedOut +from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, - locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager) + locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, + ConnectionException) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ProtocolHandler) @@ -344,7 +346,7 @@ class ConnectionHeartbeatTest(unittest.TestCase): get_holders = self.make_get_holders(1) max_connection = Mock(spec=Connection, host='localhost', lock=Lock(), - max_request_id=in_flight, in_flight=in_flight, + max_request_id=in_flight - 1, in_flight=in_flight, is_idle=True, is_defunct=False, is_closed=False) 
holder = get_holders.return_value[0] holder.get_connections.return_value.append(max_connection) @@ -382,8 +384,8 @@ class ConnectionHeartbeatTest(unittest.TestCase): connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) exc = connection.defunct.call_args_list[0][0][0] - self.assertIsInstance(exc, Exception) - self.assertEqual(exc.args, Exception('Connection heartbeat failure').args) + self.assertIsInstance(exc, ConnectionException) + self.assertRegexpMatches(exc.args[0], r'^Received unexpected response to OptionsMessage.*') holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count) def test_timeout(self, *args): @@ -410,8 +412,9 @@ class ConnectionHeartbeatTest(unittest.TestCase): connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) exc = connection.defunct.call_args_list[0][0][0] - self.assertIsInstance(exc, Exception) - self.assertEqual(exc.args, Exception('Connection heartbeat failure').args) + self.assertIsInstance(exc, OperationTimedOut) + self.assertEqual(exc.errors, 'Connection heartbeat timeout after 0.05 seconds') + self.assertEqual(exc.last_host, 'localhost') holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count) diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 9fac7e2f..e109a76f 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -25,7 +25,7 @@ from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS from cassandra.cluster import ControlConnection, _Scheduler from cassandra.pool import Host from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, - ConstantReconnectionPolicy) + ConstantReconnectionPolicy, IdentityTranslator) PEER_IP = "foobar" @@ -61,6 +61,7 @@ class MockCluster(object): max_schema_agreement_wait = 5 load_balancing_policy = RoundRobinPolicy() reconnection_policy = ConstantReconnectionPolicy(2) + address_translator = IdentityTranslator() down_host = None contact_points = [] is_shutdown = False diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index c2e3513f..0b86c487 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -380,9 +380,7 @@ class IndexTest(unittest.TestCase): column_meta.table.name = 'table_name_here' column_meta.table.keyspace_name = 'keyspace_name_here' column_meta.table.columns = {column_meta.name: column_meta} - connection = Mock() - connection.server_version = '2.1.0' - parser = get_schema_parser(connection, 0.1) + parser = get_schema_parser(Mock(), '2.1.0', 0.1) row = {'index_name': 'index_name_here', 'index_type': 'index_type_here'} index_meta = parser._build_index_metadata(column_meta, row) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index a9406cf7..6640ccf1 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -483,7 +483,7 @@ class DCAwareRoundRobinPolicyTest(unittest.TestCase): host_none = Host(1, SimpleConvictionPolicy) # contact point is '1' - cluster = Mock(contact_points=[1]) + cluster = Mock(contact_points_resolved=[1]) # contact DC first policy = DCAwareRoundRobinPolicy() @@ -916,14 +916,14 @@ class RetryPolicyTest(unittest.TestCase): retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=1, alive_replicas=2, 
retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + self.assertEqual(retry, RetryPolicy.RETRY_NEXT_HOST) + self.assertEqual(consistency, ONE) retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=10000, alive_replicas=1, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + self.assertEqual(retry, RetryPolicy.RETRY_NEXT_HOST) + self.assertEqual(consistency, ONE) class FallthroughRetryPolicyTest(unittest.TestCase):
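
The test_schema_metadata_disable test added above drives the new schema_metadata_enabled and token_metadata_enabled cluster options (PYTHON-327). Below is a minimal usage sketch, assuming a node reachable at 127.0.0.1; only the keyword arguments and result-set indexing exercised by that test are relied on, the contact point and query are illustrative.

    # Sketch: opt out of client-side schema and token metadata processing.
    # With these flags off, metadata.keyspaces stays empty and the token map
    # is not built, as asserted in test_schema_metadata_disable above.
    from cassandra.cluster import Cluster

    cluster = Cluster(contact_points=['127.0.0.1'],  # assumed contact point
                      schema_metadata_enabled=False,
                      token_metadata_enabled=False)
    session = cluster.connect()
    try:
        # Queries still work; only metadata refresh is skipped.
        print(session.execute("SELECT * FROM system.local")[0])
    finally:
        cluster.shutdown()

Skipping token metadata also leaves token-aware routing with nothing to work from, so these flags trade client-side bookkeeping for reduced routing information.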
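
test_trace_cl above exercises the new query_cl argument to ResponseFuture.get_query_trace (PYTHON-435). A short sketch of the call pattern, assuming a reachable cluster at 127.0.0.1; the statement, timeout, and consistency level are taken from the test, the contact point is an assumption.

    # Sketch: run a traced query, then fetch the trace rows from the
    # system_traces tables at an explicit consistency level.
    from cassandra import ConsistencyLevel
    from cassandra.cluster import Cluster
    from cassandra.query import SimpleStatement

    cluster = Cluster(['127.0.0.1'])  # assumed contact point
    session = cluster.connect()
    try:
        future = session.execute_async(
            SimpleStatement("SELECT * FROM system.local"), trace=True)
        future.result()  # wait for the query itself to complete
        trace = future.get_query_trace(max_wait=2.0,
                                       query_cl=ConsistencyLevel.ONE)
        print(trace.trace_id)
    finally:
        cluster.shutdown()

As the test shows, a level the trace tables cannot satisfy surfaces as Unavailable, and ANY is rejected for the read with InvalidRequest.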
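
ExceptionTypeTest above pins down the normalized exception hierarchy (PYTHON-443). A sketch of what the shared bases allow at call sites, using only subclass relationships asserted in that test; run_query and its arguments are hypothetical names, not driver API.

    # Sketch: catch driver errors at the granularity the hierarchy provides.
    from cassandra import (DriverException, RequestExecutionException,
                           RequestValidationException)

    def run_query(session, statement):  # hypothetical helper
        try:
            return session.execute(statement)
        except RequestExecutionException:
            # Unavailable, read/write timeouts and failures, coordination
            # failures: the request was valid but could not be completed.
            raise
        except RequestValidationException:
            # InvalidRequest, Unauthorized, AlreadyExists, ...: fix the request.
            raise
        except DriverException:
            # Client-side errors such as OperationTimedOut also derive from
            # DriverException, so this catches everything the driver raises.
            raise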
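
The updated RetryPolicyTest assertions above capture the new RETRY_NEXT_HOST decision (PYTHON-285): the default policy now retries the first unavailable error on another host at the original consistency level instead of rethrowing. A self-contained sketch mirroring the test's call; no cluster is needed because RetryPolicy is pure decision logic.

    # Sketch: the default RetryPolicy's reaction to an Unavailable error.
    from cassandra import ConsistencyLevel
    from cassandra.policies import RetryPolicy

    policy = RetryPolicy()
    decision, consistency = policy.on_unavailable(
        query=None, consistency=ConsistencyLevel.ONE,
        required_replicas=1, alive_replicas=2, retry_num=0)

    # First attempt: try the next host, keeping the requested consistency.
    assert decision == RetryPolicy.RETRY_NEXT_HOST
    assert consistency == ConsistencyLevel.ONE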