diff --git a/.gitignore b/.gitignore index 29785e79..d24b9945 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ nosetests.xml cover/ docs/_build/ tests/integration/ccm +setuptools*.tar.gz diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3875521a..dccd925a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,38 @@ +1.0.2 +===== +In Progress + +Bug Fixes +--------- +* With asyncorereactor, correctly handle EAGAIN/EWOULDBLOCK when the message from + Cassandra is a multiple of the read buffer size. Previously, if no more data + became available to read on the socket, the message would never be processed, + resulting in an OperationTimedOut error. +* Double quote keyspace, table and column names that require them (those using + uppercase characters or keywords) when generating CREATE statements through + KeyspaceMetadata and TableMetadata. +* Decode TimestampType as DateType. (Cassandra replaced DateType with + TimestampType to fix sorting of pre-unix epoch dates in CASSANDRA-5723.) +* Handle latest table options when parsing the schema and generating + CREATE statements. +* Avoid 'Set changed size during iteration' during query plan generation + when hosts go up or down + +Other +----- +* Remove ignored ``tracing_enabled`` parameter for ``SimpleStatement``. The + correct way to trace a query is by setting the ``trace`` argument to ``True`` + in ``Session.execute()`` and ``Session.execute_async()``. +* Raise TypeError instead of cassandra.query.InvalidParameterTypeError when + a parameter for a prepared statement has the wrong type; remove + cassandra.query.InvalidParameterTypeError. 
+* More consistent type checking for query parameters +* Add option to a return special object for empty string values for non-string + columns + 1.0.1 ===== -(In Progress) +Feb 19, 2014 Bug Fixes --------- @@ -9,12 +41,24 @@ Bug Fixes * Always close socket when defuncting error'ed connections to avoid a potential file descriptor leak * Handle "custom" types (such as the replaced DateType) correctly +* With libevreactor, correctly handle EAGAIN/EWOULDBLOCK when the message from + Cassandra is a multiple of the read buffer size. Previously, if no more data + became available to read on the socket, the message would never be processed, + resulting in an OperationTimedOut error. +* Don't break tracing when a Session's row_factory is not the default + namedtuple_factory. +* Handle data that is already utf8-encoded for UTF8Type values +* Fix token-aware routing for tokens that fall before the first node token in + the ring and tokens that exactly match a node's token +* Tolerate null source_elapsed values for Trace events. These may not be + set when events complete after the main operation has already completed. Other ----- * Skip sending OPTIONS message on connection creation if compression is disabled or not available and a CQL version has not been explicitly set +* Add details about errors and the last queried host to ``OperationTimedOut`` 1.0.0 Final =========== diff --git a/README-dev.rst b/README-dev.rst index 17902bcd..3495eb7c 100644 --- a/README-dev.rst +++ b/README-dev.rst @@ -16,12 +16,6 @@ Releasing so that it looks like ``(x, y, z, 'post')`` * Commit and push -Running the Tests -================= -In order for the extensions to be built and used in the test, run:: - - python setup.py nosetests - Building the Docs ================= Sphinx is required to build the docs. 
You probably want to install through apt, @@ -48,3 +42,26 @@ For example:: cp -R docs/_build/1.0.0-beta1/* ~/python-driver-docs/ cd ~/python-driver-docs git push origin gh-pages + +Running the Tests +================= +In order for the extensions to be built and used in the test, run:: + + python setup.py nosetests + +You can run a specific test module or package like so:: + + python setup.py nosetests -w tests/unit/ + +If you want to test all of python 2.6, 2.7, and pypy, use tox (this is what +TravisCI runs):: + + tox + +By default, tox only runs the unit tests because I haven't put in the effort +to get the integration tests to run on TravisCI. However, the integration +tests should work locally. To run them, edit the following line in tox.ini:: + + commands = {envpython} setup.py build_ext --inplace nosetests --verbosity=2 tests/unit/ + +and change ``tests/unit/`` to ``tests/``. diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 8b4ea796..bec6e759 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -9,7 +9,7 @@ class NullHandler(logging.Handler): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (1, 0, 0, 'post') +__version_info__ = (1, 0, 1, 'post') __version__ = '.'.join(map(str, __version_info__)) @@ -226,4 +226,19 @@ class OperationTimedOut(Exception): to complete. This is not an error generated by Cassandra, only the driver. """ - pass + + errors = None + """ + A dict of errors keyed by the :class:`~.Host` against which they occurred. + """ + + last_host = None + """ + The last :class:`~.Host` this operation was attempted against. 
+ """ + + def __init__(self, errors=None, last_host=None): + self.errors = errors + self.last_host = last_host + message = "errors=%s, last_host=%s" % (self.errors, self.last_host) + Exception.__init__(self, message) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 95dc3a82..ac2c9fec 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -858,8 +858,8 @@ class Cluster(object): "statement on host %s: %r", host, response) log.debug("Done preparing all known prepared statements against host %s", host) - except OperationTimedOut: - log.warn("Timed out trying to prepare all statements on host %s", host) + except OperationTimedOut as timeout: + log.warn("Timed out trying to prepare all statements on host %s: %s", host, timeout) except (ConnectionException, socket.error) as exc: log.warn("Error trying to prepare all statements on host %s: %r", host, exc) except Exception: @@ -1045,6 +1045,12 @@ class Session(object): ... log.exception("Operation failed:") """ + future = self._create_response_future(query, parameters, trace) + future.send_request() + return future + + def _create_response_future(self, query, parameters, trace): + """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None if isinstance(query, basestring): query = SimpleStatement(query) @@ -1066,11 +1072,9 @@ class Session(object): if trace: message.tracing = True - future = ResponseFuture( + return ResponseFuture( self, message, query, self.default_timeout, metrics=self._metrics, prepared_statement=prepared_statement) - future.send_request() - return future def prepare(self, query): """ @@ -1650,8 +1654,9 @@ class ControlConnection(object): timeout = min(2.0, total_timeout - elapsed) peers_result, local_result = connection.wait_for_responses( peers_query, local_query, timeout=timeout) - except OperationTimedOut: - log.debug("[control connection] Timed out waiting for response during schema agreement check") + except OperationTimedOut as timeout: + 
log.debug("[control connection] Timed out waiting for " \ + "response during schema agreement check: %s", timeout) elapsed = self._time.time() - start continue @@ -2195,7 +2200,7 @@ class ResponseFuture(object): elif self._final_exception: raise self._final_exception else: - raise OperationTimedOut() + raise OperationTimedOut(errors=self._errors, last_host=self._current_host) def get_query_trace(self, max_wait=None): """ diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index da8eb95e..7c6f5550 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -152,12 +152,35 @@ def lookup_casstype(casstype): raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e)) +class EmptyValue(object): + """ See _CassandraType.support_empty_values """ + + def __str__(self): + return "EMPTY" + __repr__ = __str__ + +EMPTY = EmptyValue() + + class _CassandraType(object): __metaclass__ = CassandraTypeType subtypes = () num_subtypes = 0 empty_binary_ok = False + support_empty_values = False + """ + Back in the Thrift days, empty strings were used for "null" values of + all types, including non-string types. For most users, an empty + string value in an int column is the same as being null/not present, + so the driver normally returns None in this case. (For string-like + types, it *will* return an empty string by default instead of None.) + + To avoid this behavior, set this to :const:`True`. Instead of returning + None for empty string values, the EMPTY singleton (an instance + of EmptyValue) will be returned. + """ + def __init__(self, val): self.val = self.validate(val) @@ -181,8 +204,10 @@ class _CassandraType(object): for more information. This method differs in that if None or the empty string is passed in, None may be returned. 
""" - if byts is None or (byts == '' and not cls.empty_binary_ok): + if byts is None: return None + elif byts == '' and not cls.empty_binary_ok: + return EMPTY if cls.support_empty_values else None return cls.deserialize(byts) @classmethod @@ -315,7 +340,10 @@ class DecimalType(_CassandraType): @staticmethod def serialize(dec): - sign, digits, exponent = dec.as_tuple() + try: + sign, digits, exponent = dec.as_tuple() + except AttributeError: + raise TypeError("Non-Decimal type received for Decimal value") unscaled = int(''.join([str(digit) for digit in digits])) if sign: unscaled *= -1 @@ -333,7 +361,10 @@ class UUIDType(_CassandraType): @staticmethod def serialize(uuid): - return uuid.bytes + try: + return uuid.bytes + except AttributeError: + raise TypeError("Got a non-UUID object for a UUID value") class BooleanType(_CassandraType): @@ -349,7 +380,7 @@ class BooleanType(_CassandraType): @staticmethod def serialize(truth): - return int8_pack(bool(truth)) + return int8_pack(truth) class AsciiType(_CassandraType): @@ -503,6 +534,10 @@ class DateType(_CassandraType): return int64_pack(long(converted)) +class TimestampType(DateType): + pass + + class TimeUUIDType(DateType): typename = 'timeuuid' @@ -515,7 +550,10 @@ class TimeUUIDType(DateType): @staticmethod def serialize(timeuuid): - return timeuuid.bytes + try: + return timeuuid.bytes + except AttributeError: + raise TypeError("Got a non-UUID object for a UUID value") class UTF8Type(_CassandraType): @@ -528,7 +566,11 @@ class UTF8Type(_CassandraType): @staticmethod def serialize(ustr): - return ustr.encode('utf8') + try: + return ustr.encode('utf-8') + except UnicodeDecodeError: + # already utf-8 + return ustr class VarcharType(UTF8Type): @@ -578,6 +620,9 @@ class _SimpleParameterizedType(_ParameterizedType): @classmethod def serialize_safe(cls, items): + if isinstance(items, basestring): + raise TypeError("Received a string for a type that expects a sequence") + subtype, = cls.subtypes buf = StringIO() 
buf.write(uint16_pack(len(items))) @@ -634,7 +679,11 @@ class MapType(_ParameterizedType): subkeytype, subvaltype = cls.subtypes buf = StringIO() buf.write(uint16_pack(len(themap))) - for key, val in themap.iteritems(): + try: + items = themap.iteritems() + except AttributeError: + raise TypeError("Got a non-map object for a map value") + for key, val in items: keybytes = subkeytype.to_binary(key) valbytes = subvaltype.to_binary(val) buf.write(uint16_pack(len(keybytes))) diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index 9c90abf4..a54aeeb9 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -260,7 +260,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): except socket.error as err: if err.args[0] not in NONBLOCKING: self.defunct(err) - return + return if self._iobuf.tell(): while True: diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index bef56cab..76b9be83 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -313,7 +313,7 @@ class LibevConnection(Connection): except socket.error as err: if err.args[0] not in NONBLOCKING: self.defunct(err) - return + return if self._iobuf.tell(): while True: diff --git a/cassandra/metadata.py b/cassandra/metadata.py index a87387f1..f0b5777a 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -1,4 +1,4 @@ -from bisect import bisect_left +from bisect import bisect_right from collections import defaultdict try: from collections import OrderedDict @@ -274,13 +274,13 @@ class Metadata(object): column_meta = self._build_column_metadata(table_meta, col_row) table_meta.columns[column_meta.name] = column_meta - table_meta.options = self._build_table_options(row, is_compact) + table_meta.options = self._build_table_options(row) + table_meta.is_compact_storage = is_compact return table_meta - def _build_table_options(self, row, is_compact_storage): + def _build_table_options(self, row): """ Setup the 
mostly-non-schema table options, like caching settings """ - options = dict((o, row.get(o)) for o in TableMetadata.recognized_options) - options["is_compact_storage"] = is_compact_storage + options = dict((o, row.get(o)) for o in TableMetadata.recognized_options if o in row) return options def _build_column_metadata(self, table_metadata, row): @@ -564,9 +564,10 @@ class KeyspaceMetadata(object): return "\n".join([self.as_cql_query()] + [t.export_as_string() for t in self.tables.values()]) def as_cql_query(self): - ret = "CREATE KEYSPACE %s WITH REPLICATION = %s " % \ - (self.name, self.replication_strategy.export_for_schema()) - return ret + (' AND DURABLE_WRITES = %s;' % ("true" if self.durable_writes else "false")) + ret = "CREATE KEYSPACE %s WITH replication = %s " % ( + protect_name(self.name), + self.replication_strategy.export_for_schema()) + return ret + (' AND durable_writes = %s;' % ("true" if self.durable_writes else "false")) class TableMetadata(object): @@ -610,6 +611,8 @@ class TableMetadata(object): A dict mapping column names to :class:`.ColumnMetadata` instances. 
""" + is_compact_storage = False + options = None """ A dict mapping table option names to their specific settings for this @@ -617,11 +620,28 @@ class TableMetadata(object): """ recognized_options = ( - "comment", "read_repair_chance", # "local_read_repair_chance", - "replicate_on_write", "gc_grace_seconds", "bloom_filter_fp_chance", - "caching", "compaction_strategy_class", "compaction_strategy_options", - "min_compaction_threshold", "max_compression_threshold", - "compression_parameters") + "comment", + "read_repair_chance", + "dclocal_read_repair_chance", + "replicate_on_write", + "gc_grace_seconds", + "bloom_filter_fp_chance", + "caching", + "compaction_strategy_class", + "compaction_strategy_options", + "min_compaction_threshold", + "max_compression_threshold", + "compression_parameters", + "min_index_interval", + "max_index_interval", + "index_interval", + "speculative_retry", + "rows_per_partition_to_cache", + "memtable_flush_period_in_ms", + "populate_io_cache_on_flush", + "compaction", + "compression", + "default_time_to_live") def __init__(self, keyspace_metadata, name, partition_key=None, clustering_key=None, columns=None, options=None): self.keyspace = keyspace_metadata @@ -653,7 +673,10 @@ class TableMetadata(object): creations are not included). If `formatted` is set to :const:`True`, extra whitespace will be added to make the query human readable. 
""" - ret = "CREATE TABLE %s.%s (%s" % (self.keyspace.name, self.name, "\n" if formatted else "") + ret = "CREATE TABLE %s.%s (%s" % ( + protect_name(self.keyspace.name), + protect_name(self.name), + "\n" if formatted else "") if formatted: column_join = ",\n" @@ -664,7 +687,7 @@ class TableMetadata(object): columns = [] for col in self.columns.values(): - columns.append("%s %s" % (col.name, col.typestring)) + columns.append("%s %s" % (protect_name(col.name), col.typestring)) if len(self.partition_key) == 1 and not self.clustering_key: columns[0] += " PRIMARY KEY" @@ -676,12 +699,12 @@ class TableMetadata(object): ret += "%s%sPRIMARY KEY (" % (column_join, padding) if len(self.partition_key) > 1: - ret += "(%s)" % ", ".join(col.name for col in self.partition_key) + ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key) else: ret += self.partition_key[0].name if self.clustering_key: - ret += ", %s" % ", ".join(col.name for col in self.clustering_key) + ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key) ret += ")" @@ -689,15 +712,15 @@ class TableMetadata(object): ret += "%s) WITH " % ("\n" if formatted else "") option_strings = [] - if self.options.get("is_compact_storage"): + if self.is_compact_storage: option_strings.append("COMPACT STORAGE") if self.clustering_key: cluster_str = "CLUSTERING ORDER BY " - clustering_names = self.protect_names([c.name for c in self.clustering_key]) + clustering_names = protect_names([c.name for c in self.clustering_key]) - if self.options.get("is_compact_storage") and \ + if self.is_compact_storage and \ not issubclass(self.comparator, types.CompositeType): subtypes = [self.comparator] else: @@ -711,52 +734,61 @@ class TableMetadata(object): cluster_str += "(%s)" % ", ".join(inner) option_strings.append(cluster_str) - option_strings.extend(map(self._make_option_str, self.recognized_options)) - option_strings = filter(lambda x: x is not None, option_strings) + 
option_strings.extend(self._make_option_strings()) join_str = "\n AND " if formatted else " AND " ret += join_str.join(option_strings) return ret - def _make_option_str(self, name): - value = self.options.get(name) - if value is not None: - if name == "comment": - value = value or "" - return "%s = %s" % (name, self.protect_value(value)) + def _make_option_strings(self): + ret = [] + for name, value in sorted(self.options.items()): + if value is not None: + if name == "comment": + value = value or "" + ret.append("%s = %s" % (name, protect_value(value))) - def protect_name(self, name): - if isinstance(name, unicode): - name = name.encode('utf8') - return self.maybe_escape_name(name) + return ret - def protect_names(self, names): - return map(self.protect_name, names) - def protect_value(self, value): - if value is None: - return 'NULL' - if isinstance(value, (int, float, bool)): - return str(value) - return "'%s'" % value.replace("'", "''") +def protect_name(name): + if isinstance(name, unicode): + name = name.encode('utf8') + return maybe_escape_name(name) - valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$') - def is_valid_name(self, name): - if name is None: - return False - if name.lower() in _keywords - _unreserved_keywords: - return False - return self.valid_cql3_word_re.match(name) is not None +def protect_names(names): + return map(protect_name, names) - def maybe_escape_name(self, name): - if self.is_valid_name(name): - return name - return self.escape_name(name) - def escape_name(self, name): - return '"%s"' % (name.replace('"', '""'),) +def protect_value(value): + if value is None: + return 'NULL' + if isinstance(value, (int, float, bool)): + return str(value) + return "'%s'" % value.replace("'", "''") + + +valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$') + + +def is_valid_name(name): + if name is None: + return False + if name.lower() in _keywords - _unreserved_keywords: + return False + return valid_cql3_word_re.match(name) is not None + + +def 
maybe_escape_name(name): + if is_valid_name(name): + return name + return escape_name(name) + + +def escape_name(name): + return '"%s"' % (name.replace('"', '""'),) class ColumnMetadata(object): @@ -825,7 +857,11 @@ class IndexMetadata(object): Returns a CQL query that can be used to recreate this index. """ table = self.column.table - return "CREATE INDEX %s ON %s.%s (%s)" % (self.name, table.keyspace.name, table.name, self.column.name) + return "CREATE INDEX %s ON %s.%s (%s)" % ( + self.name, # Cassandra doesn't like quoted index names for some reason + protect_name(table.keyspace.name), + protect_name(table.name), + protect_name(self.column.name)) class TokenMap(object): @@ -890,10 +926,11 @@ class TokenMap(object): if tokens_to_hosts is None: return [] - point = bisect_left(self.ring, token) - if point == 0 and token != self.ring[0]: - return tokens_to_hosts[self.ring[-1]] - elif point == len(self.ring): + # token range ownership is exclusive on the LHS (the start token), so + # we use bisect_right, which, in the case of a tie/exact match, + # picks an insertion point to the right of the existing match + point = bisect_right(self.ring, token) + if point == len(self.ring): return tokens_to_hosts[self.ring[0]] else: return tokens_to_hosts[self.ring[point]] diff --git a/cassandra/policies.py b/cassandra/policies.py index e866a8c4..44fdae23 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -1,6 +1,7 @@ from itertools import islice, cycle, groupby, repeat import logging from random import randint +from threading import Lock from cassandra import ConsistencyLevel @@ -78,6 +79,11 @@ class LoadBalancingPolicy(HostStateListener): custom behavior. 
""" + _hosts_lock = None + + def __init__(self): + self._hosts_lock = Lock() + def distance(self, host): """ Returns a measure of how remote a :class:`~.pool.Host` is in @@ -130,7 +136,7 @@ class RoundRobinPolicy(LoadBalancingPolicy): """ def populate(self, cluster, hosts): - self._live_hosts = set(hosts) + self._live_hosts = frozenset(hosts) if len(hosts) <= 1: self._position = 0 else: @@ -145,24 +151,29 @@ class RoundRobinPolicy(LoadBalancingPolicy): pos = self._position self._position += 1 - length = len(self._live_hosts) + hosts = self._live_hosts + length = len(hosts) if length: pos %= length - return list(islice(cycle(self._live_hosts), pos, pos + length)) + return list(islice(cycle(hosts), pos, pos + length)) else: return [] def on_up(self, host): - self._live_hosts.add(host) + with self._hosts_lock: + self._live_hosts = self._live_hosts.union((host, )) def on_down(self, host): - self._live_hosts.discard(host) + with self._hosts_lock: + self._live_hosts = self._live_hosts.difference((host, )) def on_add(self, host): - self._live_hosts.add(host) + with self._hosts_lock: + self._live_hosts = self._live_hosts.union((host, )) def on_remove(self, host): - self._live_hosts.remove(host) + with self._hosts_lock: + self._live_hosts = self._live_hosts.difference((host, )) class DCAwareRoundRobinPolicy(LoadBalancingPolicy): @@ -191,13 +202,14 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): self.local_dc = local_dc self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} + LoadBalancingPolicy.__init__(self) def _dc(self, host): return host.datacenter or self.local_dc def populate(self, cluster, hosts): for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): - self._dc_live_hosts[dc] = set(dc_hosts) + self._dc_live_hosts[dc] = frozenset(dc_hosts) # position is currently only used for local hosts local_live = self._dc_live_hosts.get(self.local_dc) @@ -244,16 +256,28 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): yield host def 
on_up(self, host): - self._dc_live_hosts.setdefault(self._dc(host), set()).add(host) + dc = self._dc(host) + with self._hosts_lock: + current_hosts = self._dc_live_hosts.setdefault(dc, frozenset()) + self._dc_live_hosts[dc] = current_hosts.union((host, )) def on_down(self, host): - self._dc_live_hosts.setdefault(self._dc(host), set()).discard(host) + dc = self._dc(host) + with self._hosts_lock: + current_hosts = self._dc_live_hosts.setdefault(dc, frozenset()) + self._dc_live_hosts[dc] = current_hosts.difference((host, )) def on_add(self, host): - self._dc_live_hosts.setdefault(self._dc(host), set()).add(host) + dc = self._dc(host) + with self._hosts_lock: + current_hosts = self._dc_live_hosts.setdefault(dc, frozenset()) + self._dc_live_hosts[dc] = current_hosts.union((host, )) def on_remove(self, host): - self._dc_live_hosts.setdefault(self._dc(host), set()).discard(host) + dc = self._dc(host) + with self._hosts_lock: + current_hosts = self._dc_live_hosts.setdefault(dc, frozenset()) + self._dc_live_hosts[dc] = current_hosts.difference((host, )) class TokenAwarePolicy(LoadBalancingPolicy): @@ -351,12 +375,10 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy): :param hosts: List of hosts """ self._allowed_hosts = hosts + RoundRobinPolicy.__init__(self) def populate(self, cluster, hosts): - self._live_hosts = set() - for host in hosts: - if host.address in self._allowed_hosts: - self._live_hosts.add(host) + self._live_hosts = frozenset(h for h in hosts if h.address in self._allowed_hosts) if len(hosts) <= 1: self._position = 0 @@ -371,14 +393,11 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy): def on_up(self, host): if host.address in self._allowed_hosts: - self._live_hosts.add(host) + RoundRobinPolicy.on_up(self, host) def on_add(self, host): if host.address in self._allowed_hosts: - self._live_hosts.add(host) - - def on_remove(self, host): - self._live_hosts.discard(host) + RoundRobinPolicy.on_add(self, host) class ConvictionPolicy(object): diff --git 
a/cassandra/query.py b/cassandra/query.py index 698eb5d6..130c1437 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -8,10 +8,10 @@ from datetime import datetime, timedelta import struct import time -from cassandra import ConsistencyLevel +from cassandra import ConsistencyLevel, OperationTimedOut from cassandra.cqltypes import unix_time_from_uuid1 from cassandra.decoder import (cql_encoders, cql_encode_object, - cql_encode_sequence) + cql_encode_sequence, named_tuple_factory) import logging log = logging.getLogger(__name__) @@ -45,10 +45,8 @@ class Statement(object): _routing_key = None - def __init__(self, retry_policy=None, tracing_enabled=False, - consistency_level=None, routing_key=None): + def __init__(self, retry_policy=None, consistency_level=None, routing_key=None): self.retry_policy = retry_policy - self.tracing_enabled = tracing_enabled if consistency_level is not None: self.consistency_level = consistency_level self._routing_key = routing_key @@ -240,10 +238,9 @@ class BoundStatement(Statement): expected_type = col_type actual_type = type(value) - err = InvalidParameterTypeError(col_name=col_name, - expected_type=expected_type, - actual_type=actual_type) - raise err + message = ('Received an argument of invalid type for column "%s". ' + 'Expected: %s, Got: %s' % (col_name, expected_type, actual_type)) + raise TypeError(message) return self @@ -321,24 +318,6 @@ class TraceUnavailable(Exception): pass -class InvalidParameterTypeError(TypeError): - """ - Raised when a used tries to bind a prepared statement with an argument of an - invalid type. - """ - - def __init__(self, col_name, expected_type, actual_type): - self.col_name = col_name - self.expected_type = expected_type - self.actual_type = actual_type - - values = (self.col_name, self.expected_type, self.actual_type) - message = ('Received an argument of invalid type for column "%s". 
' - 'Expected: %s, Got: %s' % values) - - super(InvalidParameterTypeError, self).__init__(message) - - class QueryTrace(object): """ A trace of the duration and events that occurred when executing @@ -409,9 +388,13 @@ class QueryTrace(object): attempt = 0 start = time.time() while True: - if max_wait is not None and time.time() - start >= max_wait: + time_spent = time.time() - start + if max_wait is not None and time_spent >= max_wait: raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) - session_results = self._session.execute(self._SELECT_SESSIONS_FORMAT, (self.trace_id,)) + + session_results = self._execute( + self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait) + if not session_results or session_results[0].duration is None: time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) attempt += 1 @@ -424,11 +407,25 @@ class QueryTrace(object): self.coordinator = session_row.coordinator self.parameters = session_row.parameters - event_results = self._session.execute(self._SELECT_EVENTS_FORMAT, (self.trace_id,)) + time_spent = time.time() - start + event_results = self._execute( + self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait) self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) for r in event_results) break + def _execute(self, query, parameters, time_spent, max_wait): + # in case the user switched the row factory, set it to namedtuple for this query + future = self._session._create_response_future(query, parameters, trace=False) + future.row_factory = named_tuple_factory + future.send_request() + + timeout = (max_wait - time_spent) if max_wait is not None else None + try: + return future.result(timeout=timeout) + except OperationTimedOut: + raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) + def __str__(self): return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ % 
(self.request_type, self.trace_id, self.coordinator, self.started_at, @@ -471,7 +468,10 @@ class TraceEvent(object): self.description = description self.datetime = datetime.utcfromtimestamp(unix_time_from_uuid1(timeuuid)) self.source = source - self.source_elapsed = timedelta(microseconds=source_elapsed) + if source_elapsed is not None: + self.source_elapsed = timedelta(microseconds=source_elapsed) + else: + self.source_elapsed = None self.thread_name = thread_name def __str__(self): diff --git a/docs/api/index.rst b/docs/api/index.rst index f70dfb54..a8284e2d 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -1,9 +1,6 @@ API Documentation ================= -Cassandra Modules ------------------ - .. toctree:: :maxdepth: 2 diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 15ae34a5..ab269960 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -123,7 +123,7 @@ when you execute: It is translated to the following CQL query: -.. code-block:: SQL +.. code-block:: INSERT INTO users (name, credits, user_id) VALUES ('John O''Reilly', 42, 2644bada-852c-11e3-89fb-e0b9a54a6d93) @@ -297,7 +297,7 @@ There are a few important things to remember when working with callbacks: Setting a Consistency Level -^^^^^^^^^^^^^^^^^^^^^^^^^^^ +--------------------------- The consistency level used for a query determines how many of the replicas of the data you are interacting with need to respond for the query to be considered a success. @@ -345,3 +345,34 @@ is different than for simple, non-prepared statements (although future versions of the driver may use the same placeholders for both). Cassandra 2.0 added support for named placeholders; the 1.0 version of the driver does not support them, but the 2.0 version will. + +Setting a Consistency Level with Prepared Statements +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +To specify a consistency level for prepared statements, you have two options. 
+ +The first is to set a default consistency level for every execution of the +prepared statement: + +.. code-block:: python + + from cassandra import ConsistencyLevel + + cluster = Cluster() + session = cluster.connect("mykeyspace") + user_lookup_stmt = session.prepare("SELECT * FROM users WHERE user_id=?") + user_lookup_stmt.consistency_level = ConsistencyLevel.QUORUM + + # these will both use QUORUM + user1 = session.execute(user_lookup_stmt, [user_id1])[0] + user2 = session.execute(user_lookup_stmt, [user_id2])[0] + +The second option is to create a :class:`~.BoundStatement` from the +:class:`~.PreparedStatement` and binding parameters and set a consistency +level on that: + +.. code-block:: python + + # override the QUORUM default + user3_lookup = user_lookup_stmt.bind([user_id3]) + user3_lookup.consistency_level = ConsistencyLevel.ALL + user3 = session.execute(user3_lookup) diff --git a/docs/index.rst b/docs/index.rst index b5bdea6a..573cd041 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,8 +1,6 @@ Python Cassandra Driver ======================= -Contents: - .. 
toctree:: :maxdepth: 2 diff --git a/tests/integration/long/test_large_data.py b/tests/integration/long/test_large_data.py index df46f3ae..10e0a8eb 100644 --- a/tests/integration/long/test_large_data.py +++ b/tests/integration/long/test_large_data.py @@ -31,6 +31,7 @@ class LargeDataTests(unittest.TestCase): def make_session_and_keyspace(self): cluster = Cluster() session = cluster.connect() + session.default_timeout = 20.0 # increase the default timeout session.row_factory = dict_factory create_schema(session, self.keyspace) diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index 0b1583f8..c0f8670a 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -138,8 +138,8 @@ class SchemaMetadataTest(unittest.TestCase): self.assertEqual([], tablemeta.clustering_key) self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) - for option in TableMetadata.recognized_options: - self.assertTrue(option in tablemeta.options) + for option in tablemeta.options: + self.assertIn(option, TableMetadata.recognized_options) self.check_create_statement(tablemeta, create_statement) @@ -295,7 +295,7 @@ class TestCodeCoverage(unittest.TestCase): cluster = Cluster() cluster.connect() - self.assertIsInstance(cluster.metadata.export_schema_as_string(), unicode) + self.assertIsInstance(cluster.metadata.export_schema_as_string(), basestring) def test_export_keyspace_schema(self): """ @@ -307,8 +307,47 @@ class TestCodeCoverage(unittest.TestCase): for keyspace in cluster.metadata.keyspaces: keyspace_metadata = cluster.metadata.keyspaces[keyspace] - self.assertIsInstance(keyspace_metadata.export_as_string(), unicode) - self.assertIsInstance(keyspace_metadata.as_cql_query(), unicode) + self.assertIsInstance(keyspace_metadata.export_as_string(), basestring) + self.assertIsInstance(keyspace_metadata.as_cql_query(), basestring) + + def test_case_sensitivity(self): + """ + Test that 
names that need to be escaped in CREATE statements are escaped correctly + """ + + cluster = Cluster() + session = cluster.connect() + + ksname = 'AnInterestingKeyspace' + cfname = 'AnInterestingTable' + + session.execute(""" + CREATE KEYSPACE "%s" + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + """ % (ksname,)) + session.execute(""" + CREATE TABLE "%s"."%s" ( + k int, + "A" int, + "B" int, + "MyColumn" int, + PRIMARY KEY (k, "A")) + WITH CLUSTERING ORDER BY ("A" DESC) + """ % (ksname, cfname)) + session.execute(""" + CREATE INDEX myindex ON "%s"."%s" ("MyColumn") + """ % (ksname, cfname)) + + ksmeta = cluster.metadata.keyspaces[ksname] + schema = ksmeta.export_as_string() + self.assertIn('CREATE KEYSPACE "AnInterestingKeyspace"', schema) + self.assertIn('CREATE TABLE "AnInterestingKeyspace"."AnInterestingTable"', schema) + self.assertIn('"A" int', schema) + self.assertIn('"B" int', schema) + self.assertIn('"MyColumn" int', schema) + self.assertIn('PRIMARY KEY (k, "A")', schema) + self.assertIn('WITH CLUSTERING ORDER BY ("A" DESC)', schema) + self.assertIn('CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")', schema) def test_already_exists_exceptions(self): """ @@ -365,8 +404,8 @@ class TestCodeCoverage(unittest.TestCase): for i, token in enumerate(ring): self.assertEqual(set(get_replicas('test3rf', token)), set(owners)) - self.assertEqual(set(get_replicas('test2rf', token)), set([owners[i], owners[(i + 1) % 3]])) - self.assertEqual(set(get_replicas('test1rf', token)), set([owners[i]])) + self.assertEqual(set(get_replicas('test2rf', token)), set([owners[(i + 1) % 3], owners[(i + 2) % 3]])) + self.assertEqual(set(get_replicas('test1rf', token)), set([owners[(i + 1) % 3]])) class TokenMetadataTest(unittest.TestCase): @@ -393,12 +432,13 @@ class TokenMetadataTest(unittest.TestCase): token_map = TokenMap(MD5Token, token_to_primary_replica, tokens, metadata) # tokens match node tokens exactly - for token, expected_host in
zip(tokens, hosts): + for i, token in enumerate(tokens): + expected_host = hosts[(i + 1) % len(hosts)] replicas = token_map.get_replicas("ks", token) self.assertEqual(set(replicas), set([expected_host])) # shift the tokens back by one - for token, expected_host in zip(tokens[1:], hosts[1:]): + for token, expected_host in zip(tokens, hosts): replicas = token_map.get_replicas("ks", MD5Token(str(token.value - 1))) self.assertEqual(set(replicas), set([expected_host])) diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index b6aad8a7..36e69f8b 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -5,6 +5,7 @@ except ImportError: from cassandra.query import PreparedStatement, BoundStatement, ValueSequence, SimpleStatement from cassandra.cluster import Cluster +from cassandra.decoder import dict_factory class QueryTest(unittest.TestCase): @@ -49,6 +50,20 @@ class QueryTest(unittest.TestCase): for event in statement.trace.events: str(event) + def test_trace_ignores_row_factory(self): + cluster = Cluster() + session = cluster.connect() + session.row_factory = dict_factory + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + session.execute(statement, trace=True) + + # Ensure this does not throw an exception + str(statement.trace) + for event in statement.trace.events: + str(event) + class PreparedStatementTests(unittest.TestCase): diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index 94c2706f..a350c457 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -14,6 +14,12 @@ except ImportError: from cassandra import InvalidRequest from cassandra.cluster import Cluster +from cassandra.cqltypes import Int32Type, EMPTY +from cassandra.decoder import dict_factory +try: + from collections import OrderedDict +except ImportError: + from cassandra.util import 
OrderedDict # noqa from tests.integration import get_server_versions @@ -100,16 +106,8 @@ class TypeTests(unittest.TestCase): for expected, actual in zip(expected_vals, results[0]): self.assertEquals(expected, actual) - def test_basic_types(self): - c = Cluster() - s = c.connect() - s.execute(""" - CREATE KEYSPACE typetests - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'} - """) - s.set_keyspace("typetests") - s.execute(""" - CREATE TABLE mytable ( + create_type_table = """ + CREATE TABLE mytable ( a text, b text, c ascii, @@ -131,7 +129,17 @@ class TypeTests(unittest.TestCase): t varint, PRIMARY KEY (a, b) ) - """) + """ + + def test_basic_types(self): + c = Cluster() + s = c.connect() + s.execute(""" + CREATE KEYSPACE typetests + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'} + """) + s.set_keyspace("typetests") + s.execute(self.create_type_table) v1_uuid = uuid1() v4_uuid = uuid4() @@ -220,6 +228,126 @@ class TypeTests(unittest.TestCase): for expected, actual in zip(expected_vals, results[0]): self.assertEquals(expected, actual) + def test_empty_strings_and_nones(self): + c = Cluster() + s = c.connect() + s.execute(""" + CREATE KEYSPACE test_empty_strings_and_nones + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'} + """) + s.set_keyspace("test_empty_strings_and_nones") + s.execute(self.create_type_table) + + s.execute("INSERT INTO mytable (a, b) VALUES ('a', 'b')") + s.row_factory = dict_factory + results = s.execute(""" + SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable + """) + self.assertTrue(all(x is None for x in results[0].values())) + + prepared = s.prepare(""" + SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable + """) + results = s.execute(prepared.bind(())) + self.assertTrue(all(x is None for x in results[0].values())) + + # insert empty strings for string-like fields and fetch them + s.execute("INSERT INTO mytable (a, b, c, o, 
s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)", + ('', '', '', [''], {'': 3})) + self.assertEquals( + {'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})}, + s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0]) + + self.assertEquals( + {'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})}, + s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0]) + + # non-string types shouldn't accept empty strings + for col in ('d', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'q', 'r', 't'): + query = "INSERT INTO mytable (a, b, %s) VALUES ('a', 'b', %%s)" % (col, ) + try: + s.execute(query, ['']) + except InvalidRequest: + pass + else: + self.fail("Expected an InvalidRequest error when inserting an " + "empty string for column %s" % (col, )) + + prepared = s.prepare("INSERT INTO mytable (a, b, %s) VALUES ('a', 'b', ?)" % (col, )) + try: + s.execute(prepared, ['']) + except TypeError: + pass + else: + self.fail("Expected a TypeError when inserting an " + "empty string for column %s with a prepared statement" % (col, )) + + # insert values for all columns + values = ['a', 'b', 'a', 1, True, Decimal('1.0'), 0.1, 0.1, + "1.2.3.4", 1, ['a'], set([1]), {'a': 1}, 'a', + datetime.now(), uuid4(), uuid1(), 'a', 1] + s.execute(""" + INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, values) + + # then insert None, which should null them out + null_values = values[:2] + ([None] * (len(values) - 2)) + s.execute(""" + INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, null_values) + + results = s.execute(""" + SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable + """) + self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val
is not None]) + + prepared = s.prepare(""" + SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable + """) + results = s.execute(prepared.bind(())) + self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) + + # do the same thing again, but use a prepared statement to insert the nulls + s.execute(""" + INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, values) + prepared = s.prepare(""" + INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """) + s.execute(prepared, null_values) + + results = s.execute(""" + SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable + """) + self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) + + prepared = s.prepare(""" + SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable + """) + results = s.execute(prepared.bind(())) + self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) + + def test_empty_values(self): + c = Cluster() + s = c.connect() + s.execute(""" + CREATE KEYSPACE test_empty_values + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'} + """) + s.set_keyspace("test_empty_values") + s.execute("CREATE TABLE mytable (a text PRIMARY KEY, b int)") + s.execute("INSERT INTO mytable (a, b) VALUES ('a', blobAsInt(0x))") + try: + Int32Type.support_empty_values = True + results = s.execute("SELECT b FROM mytable WHERE a='a'")[0] + self.assertIs(EMPTY, results.b) + finally: + Int32Type.support_empty_values = False + def test_timezone_aware_datetimes(self): """ Ensure timezone-aware datetimes are converted to timestamps correctly """ try: diff --git a/tests/unit/io/test_asyncorereactor.py b/tests/unit/io/test_asyncorereactor.py index 
7f1f0468..41b0df01 100644 --- a/tests/unit/io/test_asyncorereactor.py +++ b/tests/unit/io/test_asyncorereactor.py @@ -4,6 +4,7 @@ except ImportError: import unittest # noqa import errno +import os from StringIO import StringIO import socket from socket import error as socket_error @@ -15,7 +16,7 @@ from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT, from cassandra.decoder import (write_stringmultimap, write_int, write_string, SupportedMessage, ReadyMessage, ServerError) -from cassandra.marshal import uint8_pack, uint32_pack +from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.io.asyncorereactor import AsyncoreConnection @@ -84,6 +85,40 @@ class AsyncoreConnectionTest(unittest.TestCase): c.handle_read() self.assertTrue(c.connected_event.is_set()) + return c + + def test_egain_on_buffer_size(self, *args): + # get a connection that's already fully started + c = self.test_successful_connection() + + header = '\x00\x00\x00\x00' + int32_pack(20000) + responses = [ + header + ('a' * (4096 - len(header))), + 'a' * 4096, + socket_error(errno.EAGAIN), + 'a' * 100, + socket_error(errno.EAGAIN)] + + def side_effect(*args): + response = responses.pop(0) + if isinstance(response, socket_error): + raise response + else: + return response + + c.socket.recv.side_effect = side_effect + c.handle_read() + self.assertEquals(c._total_reqd_bytes, 20000 + len(header)) + # the EAGAIN prevents it from reading the last 100 bytes + c._iobuf.seek(0, os.SEEK_END) + pos = c._iobuf.tell() + self.assertEquals(pos, 4096 + 4096) + + # now tell it to read the last 100 bytes + c.handle_read() + c._iobuf.seek(0, os.SEEK_END) + pos = c._iobuf.tell() + self.assertEquals(pos, 4096 + 4096 + 100) def test_protocol_error(self, *args): c = self.make_connection() diff --git a/tests/unit/io/test_libevreactor.py b/tests/unit/io/test_libevreactor.py index b92ca5b3..408af29a 100644 --- a/tests/unit/io/test_libevreactor.py +++ b/tests/unit/io/test_libevreactor.py @@ -4,6 +4,7 
@@ except ImportError: import unittest # noqa import errno +import os from StringIO import StringIO from socket import error as socket_error @@ -14,7 +15,7 @@ from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT, from cassandra.decoder import (write_stringmultimap, write_int, write_string, SupportedMessage, ReadyMessage, ServerError) -from cassandra.marshal import uint8_pack, uint32_pack +from cassandra.marshal import uint8_pack, uint32_pack, int32_pack try: from cassandra.io.libevreactor import LibevConnection @@ -84,6 +85,40 @@ class LibevConnectionTest(unittest.TestCase): c.handle_read(None, 0) self.assertTrue(c.connected_event.is_set()) + return c + + def test_egain_on_buffer_size(self, *args): + # get a connection that's already fully started + c = self.test_successful_connection() + + header = '\x00\x00\x00\x00' + int32_pack(20000) + responses = [ + header + ('a' * (4096 - len(header))), + 'a' * 4096, + socket_error(errno.EAGAIN), + 'a' * 100, + socket_error(errno.EAGAIN)] + + def side_effect(*args): + response = responses.pop(0) + if isinstance(response, socket_error): + raise response + else: + return response + + c._socket.recv.side_effect = side_effect + c.handle_read(None, 0) + self.assertEquals(c._total_reqd_bytes, 20000 + len(header)) + # the EAGAIN prevents it from reading the last 100 bytes + c._iobuf.seek(0, os.SEEK_END) + pos = c._iobuf.tell() + self.assertEquals(pos, 4096 + 4096) + + # now tell it to read the last 100 bytes + c.handle_read(None, 0) + c._iobuf.seek(0, os.SEEK_END) + pos = c._iobuf.tell() + self.assertEquals(pos, 4096 + 4096 + 100) def test_protocol_error(self, *args): c = self.make_connection() diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index d4d66b3e..ce7b20dd 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -4,10 +4,11 @@ except ImportError: import unittest # noqa import cassandra -from cassandra.metadata import (TableMetadata, Murmur3Token, MD5Token, +from 
cassandra.metadata import (Murmur3Token, MD5Token, BytesToken, ReplicationStrategy, NetworkTopologyStrategy, SimpleStrategy, - LocalStrategy, NoMurmur3) + LocalStrategy, NoMurmur3, protect_name, + protect_names, protect_value, is_valid_name) from cassandra.policies import SimpleConvictionPolicy from cassandra.pool import Host @@ -132,31 +133,25 @@ class TestStrategies(unittest.TestCase): self.assertItemsEqual(rf3_replicas[MD5Token(200)], [host3, host1, host2]) -class TestTokens(unittest.TestCase): +class TestNameEscaping(unittest.TestCase): def test_protect_name(self): """ - Test TableMetadata.protect_name output + Test cassandra.metadata.protect_name output """ - - table_metadata = TableMetadata('ks_name', 'table_name') - - self.assertEqual(table_metadata.protect_name('tests'), 'tests') - self.assertEqual(table_metadata.protect_name('test\'s'), '"test\'s"') - self.assertEqual(table_metadata.protect_name('test\'s'), "\"test's\"") - self.assertEqual(table_metadata.protect_name('tests ?!@#$%^&*()'), '"tests ?!@#$%^&*()"') - self.assertEqual(table_metadata.protect_name('1'), '"1"') - self.assertEqual(table_metadata.protect_name('1test'), '"1test"') + self.assertEqual(protect_name('tests'), 'tests') + self.assertEqual(protect_name('test\'s'), '"test\'s"') + self.assertEqual(protect_name('test\'s'), "\"test's\"") + self.assertEqual(protect_name('tests ?!@#$%^&*()'), '"tests ?!@#$%^&*()"') + self.assertEqual(protect_name('1'), '"1"') + self.assertEqual(protect_name('1test'), '"1test"') def test_protect_names(self): """ - Test TableMetadata.protect_names output + Test cassandra.metadata.protect_names output """ - - table_metadata = TableMetadata('ks_name', 'table_name') - - self.assertEqual(table_metadata.protect_names(['tests']), ['tests']) - self.assertEqual(table_metadata.protect_names( + self.assertEqual(protect_names(['tests']), ['tests']) + self.assertEqual(protect_names( [ 'tests', 'test\'s', @@ -172,36 +167,33 @@ class TestTokens(unittest.TestCase): def 
test_protect_value(self): """ - Test TableMetadata.protect_value output + Test cassandra.metadata.protect_value output """ - - table_metadata = TableMetadata('ks_name', 'table_name') - - self.assertEqual(table_metadata.protect_value(True), "True") - self.assertEqual(table_metadata.protect_value(False), "False") - self.assertEqual(table_metadata.protect_value(3.14), '3.14') - self.assertEqual(table_metadata.protect_value(3), '3') - self.assertEqual(table_metadata.protect_value('test'), "'test'") - self.assertEqual(table_metadata.protect_value('test\'s'), "'test''s'") - self.assertEqual(table_metadata.protect_value(None), 'NULL') + self.assertEqual(protect_value(True), "True") + self.assertEqual(protect_value(False), "False") + self.assertEqual(protect_value(3.14), '3.14') + self.assertEqual(protect_value(3), '3') + self.assertEqual(protect_value('test'), "'test'") + self.assertEqual(protect_value('test\'s'), "'test''s'") + self.assertEqual(protect_value(None), 'NULL') def test_is_valid_name(self): """ - Test TableMetadata.is_valid_name output + Test cassandra.metadata.is_valid_name output """ - - table_metadata = TableMetadata('ks_name', 'table_name') - - self.assertEqual(table_metadata.is_valid_name(None), False) - self.assertEqual(table_metadata.is_valid_name('test'), True) - self.assertEqual(table_metadata.is_valid_name('Test'), False) - self.assertEqual(table_metadata.is_valid_name('t_____1'), True) - self.assertEqual(table_metadata.is_valid_name('test1'), True) - self.assertEqual(table_metadata.is_valid_name('1test1'), False) + self.assertEqual(is_valid_name(None), False) + self.assertEqual(is_valid_name('test'), True) + self.assertEqual(is_valid_name('Test'), False) + self.assertEqual(is_valid_name('t_____1'), True) + self.assertEqual(is_valid_name('test1'), True) + self.assertEqual(is_valid_name('1test1'), False) non_valid_keywords = cassandra.metadata._keywords - cassandra.metadata._unreserved_keywords for keyword in non_valid_keywords: - 
self.assertEqual(table_metadata.is_valid_name(keyword), False) + self.assertEqual(is_valid_name(keyword), False) + + +class TestTokens(unittest.TestCase): def test_murmur3_tokens(self): try: diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index 7706e6bc..3ed66647 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -5,7 +5,6 @@ except ImportError: from cassandra.query import bind_params, ValueSequence from cassandra.query import PreparedStatement, BoundStatement -from cassandra.query import InvalidParameterTypeError from cassandra.cqltypes import Int32Type try: @@ -77,21 +76,12 @@ class BoundStatementTestCase(unittest.TestCase): values = ['nonint', 1] - try: - bound_statement.bind(values) - except InvalidParameterTypeError as e: - self.assertEqual(e.col_name, 'foo1') - self.assertEqual(e.expected_type, Int32Type) - self.assertEqual(e.actual_type, str) - else: - self.fail('Passed invalid type but exception was not thrown') - try: bound_statement.bind(values) except TypeError as e: - self.assertEqual(e.col_name, 'foo1') - self.assertEqual(e.expected_type, Int32Type) - self.assertEqual(e.actual_type, str) + self.assertIn('foo1', str(e)) + self.assertIn('Int32Type', str(e)) + self.assertIn('str', str(e)) else: self.fail('Passed invalid type but exception was not thrown') @@ -99,9 +89,9 @@ class BoundStatementTestCase(unittest.TestCase): try: bound_statement.bind(values) - except InvalidParameterTypeError as e: - self.assertEqual(e.col_name, 'foo2') - self.assertEqual(e.expected_type, Int32Type) - self.assertEqual(e.actual_type, list) + except TypeError as e: + self.assertIn('foo2', str(e)) + self.assertIn('Int32Type', str(e)) + self.assertIn('list', str(e)) else: self.fail('Passed invalid type but exception was not thrown') diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 2512862a..56fd5440 100644 --- a/tests/unit/test_policies.py +++ 
b/tests/unit/test_policies.py @@ -5,6 +5,8 @@ except ImportError: from itertools import islice, cycle from mock import Mock +from random import randint +import sys import struct from threading import Thread @@ -88,11 +90,51 @@ class TestRoundRobinPolicy(unittest.TestCase): map(lambda t: t.start(), threads) map(lambda t: t.join(), threads) + def test_thread_safety_during_modification(self): + hosts = range(100) + policy = RoundRobinPolicy() + policy.populate(None, hosts) + + errors = [] + + def check_query_plan(): + try: + for i in xrange(100): + list(policy.make_query_plan()) + except Exception as exc: + errors.append(exc) + + def host_up(): + for i in xrange(1000): + policy.on_up(randint(0, 99)) + + def host_down(): + for i in xrange(1000): + policy.on_down(randint(0, 99)) + + threads = [] + for i in range(5): + threads.append(Thread(target=check_query_plan)) + threads.append(Thread(target=host_up)) + threads.append(Thread(target=host_down)) + + # make the GIL switch after every instruction, maximizing + # the chance of race conditions + original_interval = sys.getcheckinterval() + try: + sys.setcheckinterval(0) + map(lambda t: t.start(), threads) + map(lambda t: t.join(), threads) + finally: + sys.setcheckinterval(original_interval) + + if errors: + self.fail("Saw errors: %s" % (errors,)) + def test_no_live_nodes(self): """ Ensure query plan for a downed cluster will execute without errors """ - hosts = [0, 1, 2, 3] policy = RoundRobinPolicy() policy.populate(None, hosts) @@ -408,6 +450,7 @@ class TokenAwarePolicyTest(unittest.TestCase): qplan = list(policy.make_query_plan()) self.assertEqual(qplan, []) + class ConvictionPolicyTest(unittest.TestCase): def test_not_implemented(self): """