Merge branch 'master' into 2.0
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -14,3 +14,4 @@ nosetests.xml | ||||
| cover/ | ||||
| docs/_build/ | ||||
| tests/integration/ccm | ||||
| setuptools*.tar.gz | ||||
|   | ||||
| @@ -1,6 +1,38 @@ | ||||
| 1.0.2 | ||||
| ===== | ||||
| In Progress | ||||
|  | ||||
| Bug Fixes | ||||
| --------- | ||||
| * With asyncorereactor, correctly handle EAGAIN/EWOULDBLOCK when the message from | ||||
|   Cassandra is a multiple of the read buffer size.  Previously, if no more data | ||||
|   became available to read on the socket, the message would never be processed, | ||||
|   resulting in an OperationTimedOut error. | ||||
| * Double quote keyspace, table and column names that require them (those using | ||||
|   uppercase characters or keywords) when generating CREATE statements through | ||||
|   KeyspaceMetadata and TableMetadata. | ||||
| * Decode TimestampType as DateType.  (Cassandra replaced DateType with | ||||
|   TimestampType to fix sorting of pre-unix epoch dates in CASSANDRA-5723.) | ||||
| * Handle latest table options when parsing the schema and generating | ||||
|   CREATE statements. | ||||
| * Avoid 'Set changed size during iteration' during query plan generation | ||||
|   when hosts go up or down | ||||
|  | ||||
| Other | ||||
| ----- | ||||
| * Remove ignored ``tracing_enabled`` parameter for ``SimpleStatement``.  The | ||||
|   correct way to trace a query is by setting the ``trace`` argument to ``True`` | ||||
|   in ``Session.execute()`` and ``Session.execute_async()``. | ||||
| * Raise TypeError instead of cassandra.query.InvalidParameterTypeError when | ||||
|   a parameter for a prepared statement has the wrong type; remove | ||||
|   cassandra.query.InvalidParameterTypeError. | ||||
| * More consistent type checking for query parameters | ||||
| * Add option to a return special object for empty string values for non-string | ||||
|   columns | ||||
|  | ||||
| 1.0.1 | ||||
| ===== | ||||
| (In Progress) | ||||
| Feb 19, 2014 | ||||
|  | ||||
| Bug Fixes | ||||
| --------- | ||||
| @@ -9,12 +41,24 @@ Bug Fixes | ||||
| * Always close socket when defuncting error'ed connections to avoid a potential | ||||
|   file descriptor leak | ||||
| * Handle "custom" types (such as the replaced DateType) correctly | ||||
| * With libevreactor, correctly handle EAGAIN/EWOULDBLOCK when the message from | ||||
|   Cassandra is a multiple of the read buffer size.  Previously, if no more data | ||||
|   became available to read on the socket, the message would never be processed, | ||||
|   resulting in an OperationTimedOut error. | ||||
| * Don't break tracing when a Session's row_factory is not the default | ||||
|   namedtuple_factory. | ||||
| * Handle data that is already utf8-encoded for UTF8Type values | ||||
| * Fix token-aware routing for tokens that fall before the first node token in | ||||
|   the ring and tokens that exactly match a node's token | ||||
| * Tolerate null source_elapsed values for Trace events.  These may not be | ||||
|   set when events complete after the main operation has already completed. | ||||
|  | ||||
| Other | ||||
| ----- | ||||
| * Skip sending OPTIONS message on connection creation if compression is | ||||
|   disabled or not available and a CQL version has not been explicitly | ||||
|   set | ||||
| * Add details about errors and the last queried host to ``OperationTimedOut`` | ||||
|  | ||||
| 1.0.0 Final | ||||
| =========== | ||||
|   | ||||
| @@ -16,12 +16,6 @@ Releasing | ||||
|   so that it looks like ``(x, y, z, 'post')`` | ||||
| * Commit and push | ||||
|  | ||||
| Running the Tests | ||||
| ================= | ||||
| In order for the extensions to be built and used in the test, run:: | ||||
|  | ||||
|     python setup.py nosetests | ||||
|  | ||||
| Building the Docs | ||||
| ================= | ||||
| Sphinx is required to build the docs. You probably want to install through apt, | ||||
| @@ -48,3 +42,26 @@ For example:: | ||||
|     cp -R docs/_build/1.0.0-beta1/* ~/python-driver-docs/ | ||||
|     cd ~/python-driver-docs | ||||
|     git push origin gh-pages | ||||
|  | ||||
| Running the Tests | ||||
| ================= | ||||
| In order for the extensions to be built and used in the test, run:: | ||||
|  | ||||
|     python setup.py nosetests | ||||
|  | ||||
| You can run a specific test module or package like so:: | ||||
|  | ||||
|     python setup.py nosetests -w tests/unit/ | ||||
|  | ||||
| If you want to test all of python 2.6, 2.7, and pypy, use tox (this is what | ||||
| TravisCI runs):: | ||||
|  | ||||
|     tox | ||||
|  | ||||
| By default, tox only runs the unit tests because I haven't put in the effort | ||||
| to get the integration tests to run on TravicCI.  However, the integration | ||||
| tests should work locally.  To run them, edit the following line in tox.ini:: | ||||
|  | ||||
|     commands = {envpython} setup.py build_ext --inplace nosetests --verbosity=2 tests/unit/ | ||||
|  | ||||
| and change ``tests/unit/`` to ``tests/``. | ||||
|   | ||||
| @@ -9,7 +9,7 @@ class NullHandler(logging.Handler): | ||||
| logging.getLogger('cassandra').addHandler(NullHandler()) | ||||
|  | ||||
|  | ||||
| __version_info__ = (1, 0, 0, 'post') | ||||
| __version_info__ = (1, 0, 1, 'post') | ||||
| __version__ = '.'.join(map(str, __version_info__)) | ||||
|  | ||||
|  | ||||
| @@ -226,4 +226,19 @@ class OperationTimedOut(Exception): | ||||
|     to complete.  This is not an error generated by Cassandra, only | ||||
|     the driver. | ||||
|     """ | ||||
|     pass | ||||
|  | ||||
|     errors = None | ||||
|     """ | ||||
|     A dict of errors keyed by the :class:`~.Host` against which they occurred. | ||||
|     """ | ||||
|  | ||||
|     last_host = None | ||||
|     """ | ||||
|     The last :class:`~.Host` this operation was attempted against. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, errors=None, last_host=None): | ||||
|         self.errors = errors | ||||
|         self.last_host = last_host | ||||
|         message = "errors=%s, last_host=%s" % (self.errors, self.last_host) | ||||
|         Exception.__init__(self, message) | ||||
|   | ||||
| @@ -858,8 +858,8 @@ class Cluster(object): | ||||
|                                       "statement on host %s: %r", host, response) | ||||
|  | ||||
|             log.debug("Done preparing all known prepared statements against host %s", host) | ||||
|         except OperationTimedOut: | ||||
|             log.warn("Timed out trying to prepare all statements on host %s", host) | ||||
|         except OperationTimedOut as timeout: | ||||
|             log.warn("Timed out trying to prepare all statements on host %s: %s", host, timeout) | ||||
|         except (ConnectionException, socket.error) as exc: | ||||
|             log.warn("Error trying to prepare all statements on host %s: %r", host, exc) | ||||
|         except Exception: | ||||
| @@ -1045,6 +1045,12 @@ class Session(object): | ||||
|             ...     log.exception("Operation failed:") | ||||
|  | ||||
|         """ | ||||
|         future = self._create_response_future(query, parameters, trace) | ||||
|         future.send_request() | ||||
|         return future | ||||
|  | ||||
|     def _create_response_future(self, query, parameters, trace): | ||||
|         """ Returns the ResponseFuture before calling send_request() on it """ | ||||
|         prepared_statement = None | ||||
|         if isinstance(query, basestring): | ||||
|             query = SimpleStatement(query) | ||||
| @@ -1066,11 +1072,9 @@ class Session(object): | ||||
|         if trace: | ||||
|             message.tracing = True | ||||
|  | ||||
|         future = ResponseFuture( | ||||
|         return ResponseFuture( | ||||
|             self, message, query, self.default_timeout, metrics=self._metrics, | ||||
|             prepared_statement=prepared_statement) | ||||
|         future.send_request() | ||||
|         return future | ||||
|  | ||||
|     def prepare(self, query): | ||||
|         """ | ||||
| @@ -1650,8 +1654,9 @@ class ControlConnection(object): | ||||
|                     timeout = min(2.0, total_timeout - elapsed) | ||||
|                     peers_result, local_result = connection.wait_for_responses( | ||||
|                         peers_query, local_query, timeout=timeout) | ||||
|                 except OperationTimedOut: | ||||
|                     log.debug("[control connection] Timed out waiting for response during schema agreement check") | ||||
|                 except OperationTimedOut as timeout: | ||||
|                     log.debug("[control connection] Timed out waiting for " \ | ||||
|                               "response during schema agreement check: %s", timeout) | ||||
|                     elapsed = self._time.time() - start | ||||
|                     continue | ||||
|  | ||||
| @@ -2195,7 +2200,7 @@ class ResponseFuture(object): | ||||
|             elif self._final_exception: | ||||
|                 raise self._final_exception | ||||
|             else: | ||||
|                 raise OperationTimedOut() | ||||
|                 raise OperationTimedOut(errors=self._errors, last_host=self._current_host) | ||||
|  | ||||
|     def get_query_trace(self, max_wait=None): | ||||
|         """ | ||||
|   | ||||
| @@ -152,12 +152,35 @@ def lookup_casstype(casstype): | ||||
|         raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e)) | ||||
|  | ||||
|  | ||||
| class EmptyValue(object): | ||||
|     """ See _CassandraType.support_empty_values """ | ||||
|  | ||||
|     def __str__(self): | ||||
|         return "EMPTY" | ||||
|     __repr__ = __str__ | ||||
|  | ||||
| EMPTY = EmptyValue() | ||||
|  | ||||
|  | ||||
| class _CassandraType(object): | ||||
|     __metaclass__ = CassandraTypeType | ||||
|     subtypes = () | ||||
|     num_subtypes = 0 | ||||
|     empty_binary_ok = False | ||||
|  | ||||
|     support_empty_values = False | ||||
|     """ | ||||
|     Back in the Thrift days, empty strings were used for "null" values of | ||||
|     all types, including non-string types.  For most users, an empty | ||||
|     string value in an int column is the same as being null/not present, | ||||
|     so the driver normally returns None in this case.  (For string-like | ||||
|     types, it *will* return an empty string by default instead of None.) | ||||
|  | ||||
|     To avoid this behavior, set this to :const:`True`. Instead of returning | ||||
|     None for empty string values, the EMPTY singleton (an instance | ||||
|     of EmptyValue) will be returned. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, val): | ||||
|         self.val = self.validate(val) | ||||
|  | ||||
| @@ -181,8 +204,10 @@ class _CassandraType(object): | ||||
|         for more information. This method differs in that if None or the empty | ||||
|         string is passed in, None may be returned. | ||||
|         """ | ||||
|         if byts is None or (byts == '' and not cls.empty_binary_ok): | ||||
|         if byts is None: | ||||
|             return None | ||||
|         elif byts == '' and not cls.empty_binary_ok: | ||||
|             return EMPTY if cls.support_empty_values else None | ||||
|         return cls.deserialize(byts) | ||||
|  | ||||
|     @classmethod | ||||
| @@ -315,7 +340,10 @@ class DecimalType(_CassandraType): | ||||
|  | ||||
|     @staticmethod | ||||
|     def serialize(dec): | ||||
|         sign, digits, exponent = dec.as_tuple() | ||||
|         try: | ||||
|             sign, digits, exponent = dec.as_tuple() | ||||
|         except AttributeError: | ||||
|             raise TypeError("Non-Decimal type received for Decimal value") | ||||
|         unscaled = int(''.join([str(digit) for digit in digits])) | ||||
|         if sign: | ||||
|             unscaled *= -1 | ||||
| @@ -333,7 +361,10 @@ class UUIDType(_CassandraType): | ||||
|  | ||||
|     @staticmethod | ||||
|     def serialize(uuid): | ||||
|         return uuid.bytes | ||||
|         try: | ||||
|             return uuid.bytes | ||||
|         except AttributeError: | ||||
|             raise TypeError("Got a non-UUID object for a UUID value") | ||||
|  | ||||
|  | ||||
| class BooleanType(_CassandraType): | ||||
| @@ -349,7 +380,7 @@ class BooleanType(_CassandraType): | ||||
|  | ||||
|     @staticmethod | ||||
|     def serialize(truth): | ||||
|         return int8_pack(bool(truth)) | ||||
|         return int8_pack(truth) | ||||
|  | ||||
|  | ||||
| class AsciiType(_CassandraType): | ||||
| @@ -503,6 +534,10 @@ class DateType(_CassandraType): | ||||
|         return int64_pack(long(converted)) | ||||
|  | ||||
|  | ||||
| class TimestampType(DateType): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class TimeUUIDType(DateType): | ||||
|     typename = 'timeuuid' | ||||
|  | ||||
| @@ -515,7 +550,10 @@ class TimeUUIDType(DateType): | ||||
|  | ||||
|     @staticmethod | ||||
|     def serialize(timeuuid): | ||||
|         return timeuuid.bytes | ||||
|         try: | ||||
|             return timeuuid.bytes | ||||
|         except AttributeError: | ||||
|             raise TypeError("Got a non-UUID object for a UUID value") | ||||
|  | ||||
|  | ||||
| class UTF8Type(_CassandraType): | ||||
| @@ -528,7 +566,11 @@ class UTF8Type(_CassandraType): | ||||
|  | ||||
|     @staticmethod | ||||
|     def serialize(ustr): | ||||
|         return ustr.encode('utf8') | ||||
|         try: | ||||
|             return ustr.encode('utf-8') | ||||
|         except UnicodeDecodeError: | ||||
|             # already utf-8 | ||||
|             return ustr | ||||
|  | ||||
|  | ||||
| class VarcharType(UTF8Type): | ||||
| @@ -578,6 +620,9 @@ class _SimpleParameterizedType(_ParameterizedType): | ||||
|  | ||||
|     @classmethod | ||||
|     def serialize_safe(cls, items): | ||||
|         if isinstance(items, basestring): | ||||
|             raise TypeError("Received a string for a type that expects a sequence") | ||||
|  | ||||
|         subtype, = cls.subtypes | ||||
|         buf = StringIO() | ||||
|         buf.write(uint16_pack(len(items))) | ||||
| @@ -634,7 +679,11 @@ class MapType(_ParameterizedType): | ||||
|         subkeytype, subvaltype = cls.subtypes | ||||
|         buf = StringIO() | ||||
|         buf.write(uint16_pack(len(themap))) | ||||
|         for key, val in themap.iteritems(): | ||||
|         try: | ||||
|             items = themap.iteritems() | ||||
|         except AttributeError: | ||||
|             raise TypeError("Got a non-map object for a map value") | ||||
|         for key, val in items: | ||||
|             keybytes = subkeytype.to_binary(key) | ||||
|             valbytes = subvaltype.to_binary(val) | ||||
|             buf.write(uint16_pack(len(keybytes))) | ||||
|   | ||||
| @@ -260,7 +260,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): | ||||
|         except socket.error as err: | ||||
|             if err.args[0] not in NONBLOCKING: | ||||
|                 self.defunct(err) | ||||
|             return | ||||
|                 return | ||||
|  | ||||
|         if self._iobuf.tell(): | ||||
|             while True: | ||||
|   | ||||
| @@ -313,7 +313,7 @@ class LibevConnection(Connection): | ||||
|         except socket.error as err: | ||||
|             if err.args[0] not in NONBLOCKING: | ||||
|                 self.defunct(err) | ||||
|             return | ||||
|                 return | ||||
|  | ||||
|         if self._iobuf.tell(): | ||||
|             while True: | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| from bisect import bisect_left | ||||
| from bisect import bisect_right | ||||
| from collections import defaultdict | ||||
| try: | ||||
|     from collections import OrderedDict | ||||
| @@ -274,13 +274,13 @@ class Metadata(object): | ||||
|                 column_meta = self._build_column_metadata(table_meta, col_row) | ||||
|                 table_meta.columns[column_meta.name] = column_meta | ||||
|  | ||||
|         table_meta.options = self._build_table_options(row, is_compact) | ||||
|         table_meta.options = self._build_table_options(row) | ||||
|         table_meta.is_compact_storage = is_compact | ||||
|         return table_meta | ||||
|  | ||||
|     def _build_table_options(self, row, is_compact_storage): | ||||
|     def _build_table_options(self, row): | ||||
|         """ Setup the mostly-non-schema table options, like caching settings """ | ||||
|         options = dict((o, row.get(o)) for o in TableMetadata.recognized_options) | ||||
|         options["is_compact_storage"] = is_compact_storage | ||||
|         options = dict((o, row.get(o)) for o in TableMetadata.recognized_options if o in row) | ||||
|         return options | ||||
|  | ||||
|     def _build_column_metadata(self, table_metadata, row): | ||||
| @@ -564,9 +564,10 @@ class KeyspaceMetadata(object): | ||||
|         return "\n".join([self.as_cql_query()] + [t.export_as_string() for t in self.tables.values()]) | ||||
|  | ||||
|     def as_cql_query(self): | ||||
|         ret = "CREATE KEYSPACE %s WITH REPLICATION = %s " % \ | ||||
|               (self.name, self.replication_strategy.export_for_schema()) | ||||
|         return ret + (' AND DURABLE_WRITES = %s;' % ("true" if self.durable_writes else "false")) | ||||
|         ret = "CREATE KEYSPACE %s WITH replication = %s " % ( | ||||
|             protect_name(self.name), | ||||
|             self.replication_strategy.export_for_schema()) | ||||
|         return ret + (' AND durable_writes = %s;' % ("true" if self.durable_writes else "false")) | ||||
|  | ||||
|  | ||||
| class TableMetadata(object): | ||||
| @@ -610,6 +611,8 @@ class TableMetadata(object): | ||||
|     A dict mapping column names to :class:`.ColumnMetadata` instances. | ||||
|     """ | ||||
|  | ||||
|     is_compact_storage = False | ||||
|  | ||||
|     options = None | ||||
|     """ | ||||
|     A dict mapping table option names to their specific settings for this | ||||
| @@ -617,11 +620,28 @@ class TableMetadata(object): | ||||
|     """ | ||||
|  | ||||
|     recognized_options = ( | ||||
|             "comment", "read_repair_chance",  # "local_read_repair_chance", | ||||
|             "replicate_on_write", "gc_grace_seconds", "bloom_filter_fp_chance", | ||||
|             "caching", "compaction_strategy_class", "compaction_strategy_options", | ||||
|             "min_compaction_threshold", "max_compression_threshold", | ||||
|             "compression_parameters") | ||||
|         "comment", | ||||
|         "read_repair_chance", | ||||
|         "dclocal_read_repair_chance", | ||||
|         "replicate_on_write", | ||||
|         "gc_grace_seconds", | ||||
|         "bloom_filter_fp_chance", | ||||
|         "caching", | ||||
|         "compaction_strategy_class", | ||||
|         "compaction_strategy_options", | ||||
|         "min_compaction_threshold", | ||||
|         "max_compression_threshold", | ||||
|         "compression_parameters", | ||||
|         "min_index_interval", | ||||
|         "max_index_interval", | ||||
|         "index_interval", | ||||
|         "speculative_retry", | ||||
|         "rows_per_partition_to_cache", | ||||
|         "memtable_flush_period_in_ms", | ||||
|         "populate_io_cache_on_flush", | ||||
|         "compaction", | ||||
|         "compression", | ||||
|         "default_time_to_live") | ||||
|  | ||||
|     def __init__(self, keyspace_metadata, name, partition_key=None, clustering_key=None, columns=None, options=None): | ||||
|         self.keyspace = keyspace_metadata | ||||
| @@ -653,7 +673,10 @@ class TableMetadata(object): | ||||
|         creations are not included).  If `formatted` is set to :const:`True`, | ||||
|         extra whitespace will be added to make the query human readable. | ||||
|         """ | ||||
|         ret = "CREATE TABLE %s.%s (%s" % (self.keyspace.name, self.name, "\n" if formatted else "") | ||||
|         ret = "CREATE TABLE %s.%s (%s" % ( | ||||
|             protect_name(self.keyspace.name), | ||||
|             protect_name(self.name), | ||||
|             "\n" if formatted else "") | ||||
|  | ||||
|         if formatted: | ||||
|             column_join = ",\n" | ||||
| @@ -664,7 +687,7 @@ class TableMetadata(object): | ||||
|  | ||||
|         columns = [] | ||||
|         for col in self.columns.values(): | ||||
|             columns.append("%s %s" % (col.name, col.typestring)) | ||||
|             columns.append("%s %s" % (protect_name(col.name), col.typestring)) | ||||
|  | ||||
|         if len(self.partition_key) == 1 and not self.clustering_key: | ||||
|             columns[0] += " PRIMARY KEY" | ||||
| @@ -676,12 +699,12 @@ class TableMetadata(object): | ||||
|             ret += "%s%sPRIMARY KEY (" % (column_join, padding) | ||||
|  | ||||
|             if len(self.partition_key) > 1: | ||||
|                 ret += "(%s)" % ", ".join(col.name for col in self.partition_key) | ||||
|                 ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key) | ||||
|             else: | ||||
|                 ret += self.partition_key[0].name | ||||
|  | ||||
|             if self.clustering_key: | ||||
|                 ret += ", %s" % ", ".join(col.name for col in self.clustering_key) | ||||
|                 ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key) | ||||
|  | ||||
|             ret += ")" | ||||
|  | ||||
| @@ -689,15 +712,15 @@ class TableMetadata(object): | ||||
|         ret += "%s) WITH " % ("\n" if formatted else "") | ||||
|  | ||||
|         option_strings = [] | ||||
|         if self.options.get("is_compact_storage"): | ||||
|         if self.is_compact_storage: | ||||
|             option_strings.append("COMPACT STORAGE") | ||||
|  | ||||
|         if self.clustering_key: | ||||
|             cluster_str = "CLUSTERING ORDER BY " | ||||
|  | ||||
|             clustering_names = self.protect_names([c.name for c in self.clustering_key]) | ||||
|             clustering_names = protect_names([c.name for c in self.clustering_key]) | ||||
|  | ||||
|             if self.options.get("is_compact_storage") and \ | ||||
|             if self.is_compact_storage and \ | ||||
|                     not issubclass(self.comparator, types.CompositeType): | ||||
|                 subtypes = [self.comparator] | ||||
|             else: | ||||
| @@ -711,52 +734,61 @@ class TableMetadata(object): | ||||
|             cluster_str += "(%s)" % ", ".join(inner) | ||||
|             option_strings.append(cluster_str) | ||||
|  | ||||
|         option_strings.extend(map(self._make_option_str, self.recognized_options)) | ||||
|         option_strings = filter(lambda x: x is not None, option_strings) | ||||
|         option_strings.extend(self._make_option_strings()) | ||||
|  | ||||
|         join_str = "\n    AND " if formatted else " AND " | ||||
|         ret += join_str.join(option_strings) | ||||
|  | ||||
|         return ret | ||||
|  | ||||
|     def _make_option_str(self, name): | ||||
|         value = self.options.get(name) | ||||
|         if value is not None: | ||||
|             if name == "comment": | ||||
|                 value = value or "" | ||||
|             return "%s = %s" % (name, self.protect_value(value)) | ||||
|     def _make_option_strings(self): | ||||
|         ret = [] | ||||
|         for name, value in sorted(self.options.items()): | ||||
|             if value is not None: | ||||
|                 if name == "comment": | ||||
|                     value = value or "" | ||||
|                 ret.append("%s = %s" % (name, protect_value(value))) | ||||
|  | ||||
|     def protect_name(self, name): | ||||
|         if isinstance(name, unicode): | ||||
|             name = name.encode('utf8') | ||||
|         return self.maybe_escape_name(name) | ||||
|         return ret | ||||
|  | ||||
|     def protect_names(self, names): | ||||
|         return map(self.protect_name, names) | ||||
|  | ||||
|     def protect_value(self, value): | ||||
|         if value is None: | ||||
|             return 'NULL' | ||||
|         if isinstance(value, (int, float, bool)): | ||||
|             return str(value) | ||||
|         return "'%s'" % value.replace("'", "''") | ||||
| def protect_name(name): | ||||
|     if isinstance(name, unicode): | ||||
|         name = name.encode('utf8') | ||||
|     return maybe_escape_name(name) | ||||
|  | ||||
|     valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$') | ||||
|  | ||||
|     def is_valid_name(self, name): | ||||
|         if name is None: | ||||
|             return False | ||||
|         if name.lower() in _keywords - _unreserved_keywords: | ||||
|             return False | ||||
|         return self.valid_cql3_word_re.match(name) is not None | ||||
| def protect_names(names): | ||||
|     return map(protect_name, names) | ||||
|  | ||||
|     def maybe_escape_name(self, name): | ||||
|         if self.is_valid_name(name): | ||||
|             return name | ||||
|         return self.escape_name(name) | ||||
|  | ||||
|     def escape_name(self, name): | ||||
|         return '"%s"' % (name.replace('"', '""'),) | ||||
| def protect_value(value): | ||||
|     if value is None: | ||||
|         return 'NULL' | ||||
|     if isinstance(value, (int, float, bool)): | ||||
|         return str(value) | ||||
|     return "'%s'" % value.replace("'", "''") | ||||
|  | ||||
|  | ||||
| valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$') | ||||
|  | ||||
|  | ||||
| def is_valid_name(name): | ||||
|     if name is None: | ||||
|         return False | ||||
|     if name.lower() in _keywords - _unreserved_keywords: | ||||
|         return False | ||||
|     return valid_cql3_word_re.match(name) is not None | ||||
|  | ||||
|  | ||||
| def maybe_escape_name(name): | ||||
|     if is_valid_name(name): | ||||
|         return name | ||||
|     return escape_name(name) | ||||
|  | ||||
|  | ||||
| def escape_name(name): | ||||
|     return '"%s"' % (name.replace('"', '""'),) | ||||
|  | ||||
|  | ||||
| class ColumnMetadata(object): | ||||
| @@ -825,7 +857,11 @@ class IndexMetadata(object): | ||||
|         Returns a CQL query that can be used to recreate this index. | ||||
|         """ | ||||
|         table = self.column.table | ||||
|         return "CREATE INDEX %s ON %s.%s (%s)" % (self.name, table.keyspace.name, table.name, self.column.name) | ||||
|         return "CREATE INDEX %s ON %s.%s (%s)" % ( | ||||
|             self.name,  # Cassandra doesn't like quoted index names for some reason | ||||
|             protect_name(table.keyspace.name), | ||||
|             protect_name(table.name), | ||||
|             protect_name(self.column.name)) | ||||
|  | ||||
|  | ||||
| class TokenMap(object): | ||||
| @@ -890,10 +926,11 @@ class TokenMap(object): | ||||
|             if tokens_to_hosts is None: | ||||
|                 return [] | ||||
|  | ||||
|         point = bisect_left(self.ring, token) | ||||
|         if point == 0 and token != self.ring[0]: | ||||
|             return tokens_to_hosts[self.ring[-1]] | ||||
|         elif point == len(self.ring): | ||||
|         # token range ownership is exclusive on the LHS (the start token), so | ||||
|         # we use bisect_right, which, in the case of a tie/exact match, | ||||
|         # picks an insertion point to the right of the existing match | ||||
|         point = bisect_right(self.ring, token) | ||||
|         if point == len(self.ring): | ||||
|             return tokens_to_hosts[self.ring[0]] | ||||
|         else: | ||||
|             return tokens_to_hosts[self.ring[point]] | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| from itertools import islice, cycle, groupby, repeat | ||||
| import logging | ||||
| from random import randint | ||||
| from threading import Lock | ||||
|  | ||||
| from cassandra import ConsistencyLevel | ||||
|  | ||||
| @@ -78,6 +79,11 @@ class LoadBalancingPolicy(HostStateListener): | ||||
|     custom behavior. | ||||
|     """ | ||||
|  | ||||
|     _hosts_lock = None | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._hosts_lock = Lock() | ||||
|  | ||||
|     def distance(self, host): | ||||
|         """ | ||||
|         Returns a measure of how remote a :class:`~.pool.Host` is in | ||||
| @@ -130,7 +136,7 @@ class RoundRobinPolicy(LoadBalancingPolicy): | ||||
|     """ | ||||
|  | ||||
|     def populate(self, cluster, hosts): | ||||
|         self._live_hosts = set(hosts) | ||||
|         self._live_hosts = frozenset(hosts) | ||||
|         if len(hosts) <= 1: | ||||
|             self._position = 0 | ||||
|         else: | ||||
| @@ -145,24 +151,29 @@ class RoundRobinPolicy(LoadBalancingPolicy): | ||||
|         pos = self._position | ||||
|         self._position += 1 | ||||
|  | ||||
|         length = len(self._live_hosts) | ||||
|         hosts = self._live_hosts | ||||
|         length = len(hosts) | ||||
|         if length: | ||||
|             pos %= length | ||||
|             return list(islice(cycle(self._live_hosts), pos, pos + length)) | ||||
|             return list(islice(cycle(hosts), pos, pos + length)) | ||||
|         else: | ||||
|             return [] | ||||
|  | ||||
|     def on_up(self, host): | ||||
|         self._live_hosts.add(host) | ||||
|         with self._hosts_lock: | ||||
|             self._live_hosts = self._live_hosts.union((host, )) | ||||
|  | ||||
|     def on_down(self, host): | ||||
|         self._live_hosts.discard(host) | ||||
|         with self._hosts_lock: | ||||
|             self._live_hosts = self._live_hosts.difference((host, )) | ||||
|  | ||||
|     def on_add(self, host): | ||||
|         self._live_hosts.add(host) | ||||
|         with self._hosts_lock: | ||||
|             self._live_hosts = self._live_hosts.union((host, )) | ||||
|  | ||||
|     def on_remove(self, host): | ||||
|         self._live_hosts.remove(host) | ||||
|         with self._hosts_lock: | ||||
|             self._live_hosts = self._live_hosts.difference((host, )) | ||||
|  | ||||
|  | ||||
| class DCAwareRoundRobinPolicy(LoadBalancingPolicy): | ||||
| @@ -191,13 +202,14 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): | ||||
|         self.local_dc = local_dc | ||||
|         self.used_hosts_per_remote_dc = used_hosts_per_remote_dc | ||||
|         self._dc_live_hosts = {} | ||||
|         LoadBalancingPolicy.__init__(self) | ||||
|  | ||||
|     def _dc(self, host): | ||||
|         return host.datacenter or self.local_dc | ||||
|  | ||||
|     def populate(self, cluster, hosts): | ||||
|         for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): | ||||
|             self._dc_live_hosts[dc] = set(dc_hosts) | ||||
|             self._dc_live_hosts[dc] = frozenset(dc_hosts) | ||||
|  | ||||
|         # position is currently only used for local hosts | ||||
|         local_live = self._dc_live_hosts.get(self.local_dc) | ||||
| @@ -244,16 +256,28 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): | ||||
|                 yield host | ||||
|  | ||||
|     def on_up(self, host): | ||||
|         self._dc_live_hosts.setdefault(self._dc(host), set()).add(host) | ||||
|         dc = self._dc(host) | ||||
|         with self._hosts_lock: | ||||
|             current_hosts = self._dc_live_hosts.setdefault(dc, frozenset()) | ||||
|             self._dc_live_hosts[dc] = current_hosts.union((host, )) | ||||
|  | ||||
|     def on_down(self, host): | ||||
|         self._dc_live_hosts.setdefault(self._dc(host), set()).discard(host) | ||||
|         dc = self._dc(host) | ||||
|         with self._hosts_lock: | ||||
|             current_hosts = self._dc_live_hosts.setdefault(dc, frozenset()) | ||||
|             self._dc_live_hosts[dc] = current_hosts.difference((host, )) | ||||
|  | ||||
|     def on_add(self, host): | ||||
|         self._dc_live_hosts.setdefault(self._dc(host), set()).add(host) | ||||
|         dc = self._dc(host) | ||||
|         with self._hosts_lock: | ||||
|             current_hosts = self._dc_live_hosts.setdefault(dc, frozenset()) | ||||
|             self._dc_live_hosts[dc] = current_hosts.union((host, )) | ||||
|  | ||||
|     def on_remove(self, host): | ||||
|         self._dc_live_hosts.setdefault(self._dc(host), set()).discard(host) | ||||
|         dc = self._dc(host) | ||||
|         with self._hosts_lock: | ||||
|             current_hosts = self._dc_live_hosts.setdefault(dc, frozenset()) | ||||
|             self._dc_live_hosts[dc] = current_hosts.difference((host, )) | ||||
|  | ||||
|  | ||||
| class TokenAwarePolicy(LoadBalancingPolicy): | ||||
| @@ -351,12 +375,10 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy): | ||||
|         :param hosts: List of hosts | ||||
|         """ | ||||
|         self._allowed_hosts = hosts | ||||
|         RoundRobinPolicy.__init__(self) | ||||
|  | ||||
|     def populate(self, cluster, hosts): | ||||
|         self._live_hosts = set() | ||||
|         for host in hosts: | ||||
|             if host.address in self._allowed_hosts: | ||||
|                 self._live_hosts.add(host) | ||||
|         self._live_hosts = frozenset(h for h in hosts if h.address in self._allowed_hosts) | ||||
|  | ||||
|         if len(hosts) <= 1: | ||||
|             self._position = 0 | ||||
| @@ -371,14 +393,11 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy): | ||||
|  | ||||
|     def on_up(self, host): | ||||
|         if host.address in self._allowed_hosts: | ||||
|             self._live_hosts.add(host) | ||||
|             RoundRobinPolicy.on_up(self, host) | ||||
|  | ||||
|     def on_add(self, host): | ||||
|         if host.address in self._allowed_hosts: | ||||
|             self._live_hosts.add(host) | ||||
|  | ||||
|     def on_remove(self, host): | ||||
|         self._live_hosts.discard(host) | ||||
|             RoundRobinPolicy.on_add(self, host) | ||||
|  | ||||
|  | ||||
| class ConvictionPolicy(object): | ||||
|   | ||||
| @@ -8,10 +8,10 @@ from datetime import datetime, timedelta | ||||
| import struct | ||||
| import time | ||||
|  | ||||
| from cassandra import ConsistencyLevel | ||||
| from cassandra import ConsistencyLevel, OperationTimedOut | ||||
| from cassandra.cqltypes import unix_time_from_uuid1 | ||||
| from cassandra.decoder import (cql_encoders, cql_encode_object, | ||||
|                                cql_encode_sequence) | ||||
|                                cql_encode_sequence, named_tuple_factory) | ||||
|  | ||||
| import logging | ||||
| log = logging.getLogger(__name__) | ||||
| @@ -45,10 +45,8 @@ class Statement(object): | ||||
|  | ||||
|     _routing_key = None | ||||
|  | ||||
|     def __init__(self, retry_policy=None, tracing_enabled=False, | ||||
|                  consistency_level=None, routing_key=None): | ||||
|     def __init__(self, retry_policy=None, consistency_level=None, routing_key=None): | ||||
|         self.retry_policy = retry_policy | ||||
|         self.tracing_enabled = tracing_enabled | ||||
|         if consistency_level is not None: | ||||
|             self.consistency_level = consistency_level | ||||
|         self._routing_key = routing_key | ||||
| @@ -240,10 +238,9 @@ class BoundStatement(Statement): | ||||
|                     expected_type = col_type | ||||
|                     actual_type = type(value) | ||||
|  | ||||
|                     err = InvalidParameterTypeError(col_name=col_name, | ||||
|                                                     expected_type=expected_type, | ||||
|                                                     actual_type=actual_type) | ||||
|                     raise err | ||||
|                     message = ('Received an argument of invalid type for column "%s". ' | ||||
|                                'Expected: %s, Got: %s' % (col_name, expected_type, actual_type)) | ||||
|                     raise TypeError(message) | ||||
|  | ||||
|         return self | ||||
|  | ||||
| @@ -321,24 +318,6 @@ class TraceUnavailable(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class InvalidParameterTypeError(TypeError): | ||||
|     """ | ||||
|     Raised when a used tries to bind a prepared statement with an argument of an | ||||
|     invalid type. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, col_name, expected_type, actual_type): | ||||
|         self.col_name = col_name | ||||
|         self.expected_type = expected_type | ||||
|         self.actual_type = actual_type | ||||
|  | ||||
|         values = (self.col_name, self.expected_type, self.actual_type) | ||||
|         message = ('Received an argument of invalid type for column "%s". ' | ||||
|                    'Expected: %s, Got: %s' % values) | ||||
|  | ||||
|         super(InvalidParameterTypeError, self).__init__(message) | ||||
|  | ||||
|  | ||||
| class QueryTrace(object): | ||||
|     """ | ||||
|     A trace of the duration and events that occurred when executing | ||||
| @@ -409,9 +388,13 @@ class QueryTrace(object): | ||||
|         attempt = 0 | ||||
|         start = time.time() | ||||
|         while True: | ||||
|             if max_wait is not None and time.time() - start >= max_wait: | ||||
|             time_spent = time.time() - start | ||||
|             if max_wait is not None and time_spent >= max_wait: | ||||
|                 raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) | ||||
|             session_results = self._session.execute(self._SELECT_SESSIONS_FORMAT, (self.trace_id,)) | ||||
|  | ||||
|             session_results = self._execute( | ||||
|                 self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait) | ||||
|  | ||||
|             if not session_results or session_results[0].duration is None: | ||||
|                 time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) | ||||
|                 attempt += 1 | ||||
| @@ -424,11 +407,25 @@ class QueryTrace(object): | ||||
|             self.coordinator = session_row.coordinator | ||||
|             self.parameters = session_row.parameters | ||||
|  | ||||
|             event_results = self._session.execute(self._SELECT_EVENTS_FORMAT, (self.trace_id,)) | ||||
|             time_spent = time.time() - start | ||||
|             event_results = self._execute( | ||||
|                 self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait) | ||||
|             self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) | ||||
|                                 for r in event_results) | ||||
|             break | ||||
|  | ||||
|     def _execute(self, query, parameters, time_spent, max_wait): | ||||
|         # in case the user switched the row factory, set it to namedtuple for this query | ||||
|         future = self._session._create_response_future(query, parameters, trace=False) | ||||
|         future.row_factory = named_tuple_factory | ||||
|         future.send_request() | ||||
|  | ||||
|         timeout = (max_wait - time_spent) if max_wait is not None else None | ||||
|         try: | ||||
|             return future.result(timeout=timeout) | ||||
|         except OperationTimedOut: | ||||
|             raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ | ||||
|                % (self.request_type, self.trace_id, self.coordinator, self.started_at, | ||||
| @@ -471,7 +468,10 @@ class TraceEvent(object): | ||||
|         self.description = description | ||||
|         self.datetime = datetime.utcfromtimestamp(unix_time_from_uuid1(timeuuid)) | ||||
|         self.source = source | ||||
|         self.source_elapsed = timedelta(microseconds=source_elapsed) | ||||
|         if source_elapsed is not None: | ||||
|             self.source_elapsed = timedelta(microseconds=source_elapsed) | ||||
|         else: | ||||
|             self.source_elapsed = None | ||||
|         self.thread_name = thread_name | ||||
|  | ||||
|     def __str__(self): | ||||
|   | ||||
| @@ -1,9 +1,6 @@ | ||||
| API Documentation | ||||
| ================= | ||||
|  | ||||
| Cassandra Modules | ||||
| ----------------- | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 2 | ||||
|  | ||||
|   | ||||
| @@ -123,7 +123,7 @@ when you execute: | ||||
|  | ||||
| It is translated to the following CQL query: | ||||
|  | ||||
| .. code-block:: SQL | ||||
| .. code-block:: | ||||
|  | ||||
|     INSERT INTO users (name, credits, user_id) | ||||
|     VALUES ('John O''Reilly', 42, 2644bada-852c-11e3-89fb-e0b9a54a6d93) | ||||
| @@ -297,7 +297,7 @@ There are a few important things to remember when working with callbacks: | ||||
|  | ||||
|  | ||||
| Setting a Consistency Level | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
| --------------------------- | ||||
| The consistency level used for a query determines how many of the | ||||
| replicas of the data you are interacting with need to respond for | ||||
| the query to be considered a success. | ||||
| @@ -345,3 +345,34 @@ is different than for simple, non-prepared statements (although future versions | ||||
| of the driver may use the same placeholders for both).  Cassandra 2.0 added | ||||
| support for named placeholders; the 1.0 version of the driver does not support | ||||
| them, but the 2.0 version will. | ||||
|  | ||||
| Setting a Consistency Level with Prepared Statements | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
| To specify a consistency level for prepared statements, you have two options. | ||||
|  | ||||
| The first is to set a default consistency level for every execution of the | ||||
| prepared statement: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|     from cassandra import ConsistencyLevel | ||||
|  | ||||
|     cluster = Cluster() | ||||
|     session = cluster.connect("mykeyspace") | ||||
|     user_lookup_stmt = session.prepare("SELECT * FROM users WHERE user_id=?") | ||||
|     user_lookup_stmt.consistency_level = ConsistencyLevel.QUORUM | ||||
|  | ||||
|     # these will both use QUORUM | ||||
|     user1 = session.execute(user_lookup_stmt, [user_id1])[0] | ||||
|     user2 = session.execute(user_lookup_stmt, [user_id2])[0] | ||||
|  | ||||
| The second option is to create a :class:`~.BoundStatement` from the | ||||
| :class:`~.PreparedStatement` and binding paramaters and set a consistency | ||||
| level on that: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|     # override the QUORUM default | ||||
|     user3_lookup = user_lookup_stmt.bind([user_id3] | ||||
|     user3_lookup.consistency_level = ConsistencyLevel.ALL | ||||
|     user3 = session.execute(user3_lookup) | ||||
|   | ||||
| @@ -1,8 +1,6 @@ | ||||
| Python Cassandra Driver | ||||
| ======================= | ||||
|  | ||||
| Contents: | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 2 | ||||
|  | ||||
|   | ||||
| @@ -31,6 +31,7 @@ class LargeDataTests(unittest.TestCase): | ||||
|     def make_session_and_keyspace(self): | ||||
|         cluster = Cluster() | ||||
|         session = cluster.connect() | ||||
|         session.default_timeout = 20.0  # increase the default timeout | ||||
|         session.row_factory = dict_factory | ||||
|  | ||||
|         create_schema(session, self.keyspace) | ||||
|   | ||||
| @@ -138,8 +138,8 @@ class SchemaMetadataTest(unittest.TestCase): | ||||
|         self.assertEqual([], tablemeta.clustering_key) | ||||
|         self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) | ||||
|  | ||||
|         for option in TableMetadata.recognized_options: | ||||
|             self.assertTrue(option in tablemeta.options) | ||||
|         for option in tablemeta.options: | ||||
|             self.assertIn(option, TableMetadata.recognized_options) | ||||
|  | ||||
|         self.check_create_statement(tablemeta, create_statement) | ||||
|  | ||||
| @@ -295,7 +295,7 @@ class TestCodeCoverage(unittest.TestCase): | ||||
|         cluster = Cluster() | ||||
|         cluster.connect() | ||||
|  | ||||
|         self.assertIsInstance(cluster.metadata.export_schema_as_string(), unicode) | ||||
|         self.assertIsInstance(cluster.metadata.export_schema_as_string(), basestring) | ||||
|  | ||||
|     def test_export_keyspace_schema(self): | ||||
|         """ | ||||
| @@ -307,8 +307,47 @@ class TestCodeCoverage(unittest.TestCase): | ||||
|  | ||||
|         for keyspace in cluster.metadata.keyspaces: | ||||
|             keyspace_metadata = cluster.metadata.keyspaces[keyspace] | ||||
|             self.assertIsInstance(keyspace_metadata.export_as_string(), unicode) | ||||
|             self.assertIsInstance(keyspace_metadata.as_cql_query(), unicode) | ||||
|             self.assertIsInstance(keyspace_metadata.export_as_string(), basestring) | ||||
|             self.assertIsInstance(keyspace_metadata.as_cql_query(), basestring) | ||||
|  | ||||
|     def test_case_sensitivity(self): | ||||
|         """ | ||||
|         Test that names that need to be escaped in CREATE statements are | ||||
|         """ | ||||
|  | ||||
|         cluster = Cluster() | ||||
|         session = cluster.connect() | ||||
|  | ||||
|         ksname = 'AnInterestingKeyspace' | ||||
|         cfname = 'AnInterestingTable' | ||||
|  | ||||
|         session.execute(""" | ||||
|             CREATE KEYSPACE "%s" | ||||
|             WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} | ||||
|             """ % (ksname,)) | ||||
|         session.execute(""" | ||||
|             CREATE TABLE "%s"."%s" ( | ||||
|                 k int, | ||||
|                 "A" int, | ||||
|                 "B" int, | ||||
|                 "MyColumn" int, | ||||
|                 PRIMARY KEY (k, "A")) | ||||
|             WITH CLUSTERING ORDER BY ("A" DESC) | ||||
|             """ % (ksname, cfname)) | ||||
|         session.execute(""" | ||||
|             CREATE INDEX myindex ON "%s"."%s" ("MyColumn") | ||||
|             """ % (ksname, cfname)) | ||||
|  | ||||
|         ksmeta = cluster.metadata.keyspaces[ksname] | ||||
|         schema = ksmeta.export_as_string() | ||||
|         self.assertIn('CREATE KEYSPACE "AnInterestingKeyspace"', schema) | ||||
|         self.assertIn('CREATE TABLE "AnInterestingKeyspace"."AnInterestingTable"', schema) | ||||
|         self.assertIn('"A" int', schema) | ||||
|         self.assertIn('"B" int', schema) | ||||
|         self.assertIn('"MyColumn" int', schema) | ||||
|         self.assertIn('PRIMARY KEY (k, "A")', schema) | ||||
|         self.assertIn('WITH CLUSTERING ORDER BY ("A" DESC)', schema) | ||||
|         self.assertIn('CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")', schema) | ||||
|  | ||||
|     def test_already_exists_exceptions(self): | ||||
|         """ | ||||
| @@ -365,8 +404,8 @@ class TestCodeCoverage(unittest.TestCase): | ||||
|  | ||||
|         for i, token in enumerate(ring): | ||||
|             self.assertEqual(set(get_replicas('test3rf', token)), set(owners)) | ||||
|             self.assertEqual(set(get_replicas('test2rf', token)), set([owners[i], owners[(i + 1) % 3]])) | ||||
|             self.assertEqual(set(get_replicas('test1rf', token)), set([owners[i]])) | ||||
|             self.assertEqual(set(get_replicas('test2rf', token)), set([owners[(i + 1) % 3], owners[(i + 2) % 3]])) | ||||
|             self.assertEqual(set(get_replicas('test1rf', token)), set([owners[(i + 1) % 3]])) | ||||
|  | ||||
|  | ||||
| class TokenMetadataTest(unittest.TestCase): | ||||
| @@ -393,12 +432,13 @@ class TokenMetadataTest(unittest.TestCase): | ||||
|         token_map = TokenMap(MD5Token, token_to_primary_replica, tokens, metadata) | ||||
|  | ||||
|         # tokens match node tokens exactly | ||||
|         for token, expected_host in zip(tokens, hosts): | ||||
|         for i, token in enumerate(tokens): | ||||
|             expected_host = hosts[(i + 1) % len(hosts)] | ||||
|             replicas = token_map.get_replicas("ks", token) | ||||
|             self.assertEqual(set(replicas), set([expected_host])) | ||||
|  | ||||
|         # shift the tokens back by one | ||||
|         for token, expected_host in zip(tokens[1:], hosts[1:]): | ||||
|         for token, expected_host in zip(tokens, hosts): | ||||
|             replicas = token_map.get_replicas("ks", MD5Token(str(token.value - 1))) | ||||
|             self.assertEqual(set(replicas), set([expected_host])) | ||||
|  | ||||
|   | ||||
| @@ -5,6 +5,7 @@ except ImportError: | ||||
|  | ||||
| from cassandra.query import PreparedStatement, BoundStatement, ValueSequence, SimpleStatement | ||||
| from cassandra.cluster import Cluster | ||||
| from cassandra.decoder import dict_factory | ||||
|  | ||||
|  | ||||
| class QueryTest(unittest.TestCase): | ||||
| @@ -49,6 +50,20 @@ class QueryTest(unittest.TestCase): | ||||
|         for event in statement.trace.events: | ||||
|             str(event) | ||||
|  | ||||
|     def test_trace_ignores_row_factory(self): | ||||
|         cluster = Cluster() | ||||
|         session = cluster.connect() | ||||
|         session.row_factory = dict_factory | ||||
|  | ||||
|         query = "SELECT * FROM system.local" | ||||
|         statement = SimpleStatement(query) | ||||
|         session.execute(statement, trace=True) | ||||
|  | ||||
|         # Ensure this does not throw an exception | ||||
|         str(statement.trace) | ||||
|         for event in statement.trace.events: | ||||
|             str(event) | ||||
|  | ||||
|  | ||||
| class PreparedStatementTests(unittest.TestCase): | ||||
|  | ||||
|   | ||||
| @@ -14,6 +14,12 @@ except ImportError: | ||||
|  | ||||
| from cassandra import InvalidRequest | ||||
| from cassandra.cluster import Cluster | ||||
| from cassandra.cqltypes import Int32Type, EMPTY | ||||
| from cassandra.decoder import dict_factory | ||||
| try: | ||||
|     from collections import OrderedDict | ||||
| except ImportError: | ||||
|     from cassandra.util import OrderedDict  # noqa | ||||
|  | ||||
| from tests.integration import get_server_versions | ||||
|  | ||||
| @@ -100,16 +106,8 @@ class TypeTests(unittest.TestCase): | ||||
|         for expected, actual in zip(expected_vals, results[0]): | ||||
|             self.assertEquals(expected, actual) | ||||
|  | ||||
|     def test_basic_types(self): | ||||
|         c = Cluster() | ||||
|         s = c.connect() | ||||
|         s.execute(""" | ||||
|             CREATE KEYSPACE typetests | ||||
|             WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'} | ||||
|             """) | ||||
|         s.set_keyspace("typetests") | ||||
|         s.execute(""" | ||||
|             CREATE TABLE mytable ( | ||||
|     create_type_table = """ | ||||
|         CREATE TABLE mytable ( | ||||
|                 a text, | ||||
|                 b text, | ||||
|                 c ascii, | ||||
| @@ -131,7 +129,17 @@ class TypeTests(unittest.TestCase): | ||||
|                 t varint, | ||||
|                 PRIMARY KEY (a, b) | ||||
|             ) | ||||
|         """) | ||||
|         """ | ||||
|  | ||||
|     def test_basic_types(self): | ||||
|         c = Cluster() | ||||
|         s = c.connect() | ||||
|         s.execute(""" | ||||
|             CREATE KEYSPACE typetests | ||||
|             WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'} | ||||
|             """) | ||||
|         s.set_keyspace("typetests") | ||||
|         s.execute(self.create_type_table) | ||||
|  | ||||
|         v1_uuid = uuid1() | ||||
|         v4_uuid = uuid4() | ||||
| @@ -220,6 +228,126 @@ class TypeTests(unittest.TestCase): | ||||
|         for expected, actual in zip(expected_vals, results[0]): | ||||
|             self.assertEquals(expected, actual) | ||||
|  | ||||
|     def test_empty_strings_and_nones(self): | ||||
|         c = Cluster() | ||||
|         s = c.connect() | ||||
|         s.execute(""" | ||||
|             CREATE KEYSPACE test_empty_strings_and_nones | ||||
|             WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'} | ||||
|             """) | ||||
|         s.set_keyspace("test_empty_strings_and_nones") | ||||
|         s.execute(self.create_type_table) | ||||
|  | ||||
|         s.execute("INSERT INTO mytable (a, b) VALUES ('a', 'b')") | ||||
|         s.row_factory = dict_factory | ||||
|         results = s.execute(""" | ||||
|             SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable | ||||
|             """) | ||||
|         self.assertTrue(all(x is None for x in results[0].values())) | ||||
|  | ||||
|         prepared = s.prepare(""" | ||||
|             SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable | ||||
|             """) | ||||
|         results = s.execute(prepared.bind(())) | ||||
|         self.assertTrue(all(x is None for x in results[0].values())) | ||||
|  | ||||
|         # insert empty strings for string-like fields and fetch them | ||||
|         s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)", | ||||
|                   ('', '', '', [''], {'': 3})) | ||||
|         self.assertEquals( | ||||
|             {'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})}, | ||||
|             s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0]) | ||||
|  | ||||
|         self.assertEquals( | ||||
|             {'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})}, | ||||
|             s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0]) | ||||
|  | ||||
|         # non-string types shouldn't accept empty strings | ||||
|         for col in ('d', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'q', 'r', 't'): | ||||
|             query = "INSERT INTO mytable (a, b, %s) VALUES ('a', 'b', %%s)" % (col, ) | ||||
|             try: | ||||
|                 s.execute(query, ['']) | ||||
|             except InvalidRequest: | ||||
|                 pass | ||||
|             else: | ||||
|                 self.fail("Expected an InvalidRequest error when inserting an " | ||||
|                           "emptry string for column %s" % (col, )) | ||||
|  | ||||
|             prepared = s.prepare("INSERT INTO mytable (a, b, %s) VALUES ('a', 'b', ?)" % (col, )) | ||||
|             try: | ||||
|                 s.execute(prepared, ['']) | ||||
|             except TypeError: | ||||
|                 pass | ||||
|             else: | ||||
|                 self.fail("Expected an InvalidRequest error when inserting an " | ||||
|                           "emptry string for column %s with a prepared statement" % (col, )) | ||||
|  | ||||
|         # insert values for all columns | ||||
|         values = ['a', 'b', 'a', 1, True, Decimal('1.0'), 0.1, 0.1, | ||||
|                   "1.2.3.4", 1, ['a'], set([1]), {'a': 1}, 'a', | ||||
|                   datetime.now(), uuid4(), uuid1(), 'a', 1] | ||||
|         s.execute(""" | ||||
|             INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) | ||||
|             VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) | ||||
|             """, values) | ||||
|  | ||||
|         # then insert None, which should null them out | ||||
|         null_values = values[:2] + ([None] * (len(values) - 2)) | ||||
|         s.execute(""" | ||||
|             INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) | ||||
|             VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) | ||||
|             """, null_values) | ||||
|  | ||||
|         results = s.execute(""" | ||||
|             SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable | ||||
|             """) | ||||
|         self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) | ||||
|  | ||||
|         prepared = s.prepare(""" | ||||
|             SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable | ||||
|             """) | ||||
|         results = s.execute(prepared.bind(())) | ||||
|         self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) | ||||
|  | ||||
|         # do the same thing again, but use a prepared statement to insert the nulls | ||||
|         s.execute(""" | ||||
|             INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) | ||||
|             VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) | ||||
|             """, values) | ||||
|         prepared = s.prepare(""" | ||||
|             INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) | ||||
|             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) | ||||
|             """) | ||||
|         s.execute(prepared, null_values) | ||||
|  | ||||
|         results = s.execute(""" | ||||
|             SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable | ||||
|             """) | ||||
|         self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) | ||||
|  | ||||
|         prepared = s.prepare(""" | ||||
|             SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable | ||||
|             """) | ||||
|         results = s.execute(prepared.bind(())) | ||||
|         self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) | ||||
|  | ||||
|     def test_empty_values(self): | ||||
|         c = Cluster() | ||||
|         s = c.connect() | ||||
|         s.execute(""" | ||||
|             CREATE KEYSPACE test_empty_values | ||||
|             WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'} | ||||
|             """) | ||||
|         s.set_keyspace("test_empty_values") | ||||
|         s.execute("CREATE TABLE mytable (a text PRIMARY KEY, b int)") | ||||
|         s.execute("INSERT INTO mytable (a, b) VALUES ('a', blobAsInt(0x))") | ||||
|         try: | ||||
|             Int32Type.support_empty_values = True | ||||
|             results = s.execute("SELECT b FROM mytable WHERE a='a'")[0] | ||||
|             self.assertIs(EMPTY, results.b) | ||||
|         finally: | ||||
|             Int32Type.support_empty_values = False | ||||
|  | ||||
|     def test_timezone_aware_datetimes(self): | ||||
|         """ Ensure timezone-aware datetimes are converted to timestamps correctly """ | ||||
|         try: | ||||
|   | ||||
| @@ -4,6 +4,7 @@ except ImportError: | ||||
|     import unittest # noqa | ||||
|  | ||||
| import errno | ||||
| import os | ||||
| from StringIO import StringIO | ||||
| import socket | ||||
| from socket import error as socket_error | ||||
| @@ -15,7 +16,7 @@ from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT, | ||||
|  | ||||
| from cassandra.decoder import (write_stringmultimap, write_int, write_string, | ||||
|                                SupportedMessage, ReadyMessage, ServerError) | ||||
| from cassandra.marshal import uint8_pack, uint32_pack | ||||
| from cassandra.marshal import uint8_pack, uint32_pack, int32_pack | ||||
|  | ||||
| from cassandra.io.asyncorereactor import AsyncoreConnection | ||||
|  | ||||
| @@ -84,6 +85,40 @@ class AsyncoreConnectionTest(unittest.TestCase): | ||||
|         c.handle_read() | ||||
|  | ||||
|         self.assertTrue(c.connected_event.is_set()) | ||||
|         return c | ||||
|  | ||||
|     def test_egain_on_buffer_size(self, *args): | ||||
|         # get a connection that's already fully started | ||||
|         c = self.test_successful_connection() | ||||
|  | ||||
|         header = '\x00\x00\x00\x00' + int32_pack(20000) | ||||
|         responses = [ | ||||
|             header + ('a' * (4096 - len(header))), | ||||
|             'a' * 4096, | ||||
|             socket_error(errno.EAGAIN), | ||||
|             'a' * 100, | ||||
|             socket_error(errno.EAGAIN)] | ||||
|  | ||||
|         def side_effect(*args): | ||||
|             response = responses.pop(0) | ||||
|             if isinstance(response, socket_error): | ||||
|                 raise response | ||||
|             else: | ||||
|                 return response | ||||
|  | ||||
|         c.socket.recv.side_effect = side_effect | ||||
|         c.handle_read() | ||||
|         self.assertEquals(c._total_reqd_bytes, 20000 + len(header)) | ||||
|         # the EAGAIN prevents it from reading the last 100 bytes | ||||
|         c._iobuf.seek(0, os.SEEK_END) | ||||
|         pos = c._iobuf.tell() | ||||
|         self.assertEquals(pos, 4096 + 4096) | ||||
|  | ||||
|         # now tell it to read the last 100 bytes | ||||
|         c.handle_read() | ||||
|         c._iobuf.seek(0, os.SEEK_END) | ||||
|         pos = c._iobuf.tell() | ||||
|         self.assertEquals(pos, 4096 + 4096 + 100) | ||||
|  | ||||
|     def test_protocol_error(self, *args): | ||||
|         c = self.make_connection() | ||||
|   | ||||
| @@ -4,6 +4,7 @@ except ImportError: | ||||
|     import unittest # noqa | ||||
|  | ||||
| import errno | ||||
| import os | ||||
| from StringIO import StringIO | ||||
| from socket import error as socket_error | ||||
|  | ||||
| @@ -14,7 +15,7 @@ from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT, | ||||
|  | ||||
| from cassandra.decoder import (write_stringmultimap, write_int, write_string, | ||||
|                                SupportedMessage, ReadyMessage, ServerError) | ||||
| from cassandra.marshal import uint8_pack, uint32_pack | ||||
| from cassandra.marshal import uint8_pack, uint32_pack, int32_pack | ||||
|  | ||||
| try: | ||||
|     from cassandra.io.libevreactor import LibevConnection | ||||
| @@ -84,6 +85,40 @@ class LibevConnectionTest(unittest.TestCase): | ||||
|         c.handle_read(None, 0) | ||||
|  | ||||
|         self.assertTrue(c.connected_event.is_set()) | ||||
|         return c | ||||
|  | ||||
|     def test_egain_on_buffer_size(self, *args): | ||||
|         # get a connection that's already fully started | ||||
|         c = self.test_successful_connection() | ||||
|  | ||||
|         header = '\x00\x00\x00\x00' + int32_pack(20000) | ||||
|         responses = [ | ||||
|             header + ('a' * (4096 - len(header))), | ||||
|             'a' * 4096, | ||||
|             socket_error(errno.EAGAIN), | ||||
|             'a' * 100, | ||||
|             socket_error(errno.EAGAIN)] | ||||
|  | ||||
|         def side_effect(*args): | ||||
|             response = responses.pop(0) | ||||
|             if isinstance(response, socket_error): | ||||
|                 raise response | ||||
|             else: | ||||
|                 return response | ||||
|  | ||||
|         c._socket.recv.side_effect = side_effect | ||||
|         c.handle_read(None, 0) | ||||
|         self.assertEquals(c._total_reqd_bytes, 20000 + len(header)) | ||||
|         # the EAGAIN prevents it from reading the last 100 bytes | ||||
|         c._iobuf.seek(0, os.SEEK_END) | ||||
|         pos = c._iobuf.tell() | ||||
|         self.assertEquals(pos, 4096 + 4096) | ||||
|  | ||||
|         # now tell it to read the last 100 bytes | ||||
|         c.handle_read(None, 0) | ||||
|         c._iobuf.seek(0, os.SEEK_END) | ||||
|         pos = c._iobuf.tell() | ||||
|         self.assertEquals(pos, 4096 + 4096 + 100) | ||||
|  | ||||
|     def test_protocol_error(self, *args): | ||||
|         c = self.make_connection() | ||||
|   | ||||
| @@ -4,10 +4,11 @@ except ImportError: | ||||
|     import unittest # noqa | ||||
|  | ||||
| import cassandra | ||||
| from cassandra.metadata import (TableMetadata, Murmur3Token, MD5Token, | ||||
| from cassandra.metadata import (Murmur3Token, MD5Token, | ||||
|                                 BytesToken, ReplicationStrategy, | ||||
|                                 NetworkTopologyStrategy, SimpleStrategy, | ||||
|                                 LocalStrategy, NoMurmur3) | ||||
|                                 LocalStrategy, NoMurmur3, protect_name, | ||||
|                                 protect_names, protect_value, is_valid_name) | ||||
| from cassandra.policies import SimpleConvictionPolicy | ||||
| from cassandra.pool import Host | ||||
|  | ||||
| @@ -132,31 +133,25 @@ class TestStrategies(unittest.TestCase): | ||||
|         self.assertItemsEqual(rf3_replicas[MD5Token(200)], [host3, host1, host2]) | ||||
|  | ||||
|  | ||||
| class TestTokens(unittest.TestCase): | ||||
| class TestNameEscaping(unittest.TestCase): | ||||
|  | ||||
|     def test_protect_name(self): | ||||
|         """ | ||||
|         Test TableMetadata.protect_name output | ||||
|         Test cassandra.metadata.protect_name output | ||||
|         """ | ||||
|  | ||||
|         table_metadata = TableMetadata('ks_name', 'table_name') | ||||
|  | ||||
|         self.assertEqual(table_metadata.protect_name('tests'), 'tests') | ||||
|         self.assertEqual(table_metadata.protect_name('test\'s'), '"test\'s"') | ||||
|         self.assertEqual(table_metadata.protect_name('test\'s'), "\"test's\"") | ||||
|         self.assertEqual(table_metadata.protect_name('tests ?!@#$%^&*()'), '"tests ?!@#$%^&*()"') | ||||
|         self.assertEqual(table_metadata.protect_name('1'), '"1"') | ||||
|         self.assertEqual(table_metadata.protect_name('1test'), '"1test"') | ||||
|         self.assertEqual(protect_name('tests'), 'tests') | ||||
|         self.assertEqual(protect_name('test\'s'), '"test\'s"') | ||||
|         self.assertEqual(protect_name('test\'s'), "\"test's\"") | ||||
|         self.assertEqual(protect_name('tests ?!@#$%^&*()'), '"tests ?!@#$%^&*()"') | ||||
|         self.assertEqual(protect_name('1'), '"1"') | ||||
|         self.assertEqual(protect_name('1test'), '"1test"') | ||||
|  | ||||
|     def test_protect_names(self): | ||||
|         """ | ||||
|         Test TableMetadata.protect_names output | ||||
|         Test cassandra.metadata.protect_names output | ||||
|         """ | ||||
|  | ||||
|         table_metadata = TableMetadata('ks_name', 'table_name') | ||||
|  | ||||
|         self.assertEqual(table_metadata.protect_names(['tests']), ['tests']) | ||||
|         self.assertEqual(table_metadata.protect_names( | ||||
|         self.assertEqual(protect_names(['tests']), ['tests']) | ||||
|         self.assertEqual(protect_names( | ||||
|             [ | ||||
|                 'tests', | ||||
|                 'test\'s', | ||||
| @@ -172,36 +167,33 @@ class TestTokens(unittest.TestCase): | ||||
|  | ||||
|     def test_protect_value(self): | ||||
|         """ | ||||
|         Test TableMetadata.protect_value output | ||||
|         Test cassandra.metadata.protect_value output | ||||
|         """ | ||||
|  | ||||
|         table_metadata = TableMetadata('ks_name', 'table_name') | ||||
|  | ||||
|         self.assertEqual(table_metadata.protect_value(True), "True") | ||||
|         self.assertEqual(table_metadata.protect_value(False), "False") | ||||
|         self.assertEqual(table_metadata.protect_value(3.14), '3.14') | ||||
|         self.assertEqual(table_metadata.protect_value(3), '3') | ||||
|         self.assertEqual(table_metadata.protect_value('test'), "'test'") | ||||
|         self.assertEqual(table_metadata.protect_value('test\'s'), "'test''s'") | ||||
|         self.assertEqual(table_metadata.protect_value(None), 'NULL') | ||||
|         self.assertEqual(protect_value(True), "True") | ||||
|         self.assertEqual(protect_value(False), "False") | ||||
|         self.assertEqual(protect_value(3.14), '3.14') | ||||
|         self.assertEqual(protect_value(3), '3') | ||||
|         self.assertEqual(protect_value('test'), "'test'") | ||||
|         self.assertEqual(protect_value('test\'s'), "'test''s'") | ||||
|         self.assertEqual(protect_value(None), 'NULL') | ||||
|  | ||||
|     def test_is_valid_name(self): | ||||
|         """ | ||||
|         Test TableMetadata.is_valid_name output | ||||
|         Test cassandra.metadata.is_valid_name output | ||||
|         """ | ||||
|  | ||||
|         table_metadata = TableMetadata('ks_name', 'table_name') | ||||
|  | ||||
|         self.assertEqual(table_metadata.is_valid_name(None), False) | ||||
|         self.assertEqual(table_metadata.is_valid_name('test'), True) | ||||
|         self.assertEqual(table_metadata.is_valid_name('Test'), False) | ||||
|         self.assertEqual(table_metadata.is_valid_name('t_____1'), True) | ||||
|         self.assertEqual(table_metadata.is_valid_name('test1'), True) | ||||
|         self.assertEqual(table_metadata.is_valid_name('1test1'), False) | ||||
|         self.assertEqual(is_valid_name(None), False) | ||||
|         self.assertEqual(is_valid_name('test'), True) | ||||
|         self.assertEqual(is_valid_name('Test'), False) | ||||
|         self.assertEqual(is_valid_name('t_____1'), True) | ||||
|         self.assertEqual(is_valid_name('test1'), True) | ||||
|         self.assertEqual(is_valid_name('1test1'), False) | ||||
|  | ||||
|         non_valid_keywords = cassandra.metadata._keywords - cassandra.metadata._unreserved_keywords | ||||
|         for keyword in non_valid_keywords: | ||||
|             self.assertEqual(table_metadata.is_valid_name(keyword), False) | ||||
|             self.assertEqual(is_valid_name(keyword), False) | ||||
|  | ||||
|  | ||||
| class TestTokens(unittest.TestCase): | ||||
|  | ||||
|     def test_murmur3_tokens(self): | ||||
|         try: | ||||
|   | ||||
| @@ -5,7 +5,6 @@ except ImportError: | ||||
|  | ||||
| from cassandra.query import bind_params, ValueSequence | ||||
| from cassandra.query import PreparedStatement, BoundStatement | ||||
| from cassandra.query import InvalidParameterTypeError | ||||
| from cassandra.cqltypes import Int32Type | ||||
|  | ||||
| try: | ||||
| @@ -77,21 +76,12 @@ class BoundStatementTestCase(unittest.TestCase): | ||||
|  | ||||
|         values = ['nonint', 1] | ||||
|  | ||||
|         try: | ||||
|             bound_statement.bind(values) | ||||
|         except InvalidParameterTypeError as e: | ||||
|             self.assertEqual(e.col_name, 'foo1') | ||||
|             self.assertEqual(e.expected_type, Int32Type) | ||||
|             self.assertEqual(e.actual_type, str) | ||||
|         else: | ||||
|             self.fail('Passed invalid type but exception was not thrown') | ||||
|  | ||||
|         try: | ||||
|             bound_statement.bind(values) | ||||
|         except TypeError as e: | ||||
|             self.assertEqual(e.col_name, 'foo1') | ||||
|             self.assertEqual(e.expected_type, Int32Type) | ||||
|             self.assertEqual(e.actual_type, str) | ||||
|             self.assertIn('foo1', str(e)) | ||||
|             self.assertIn('Int32Type', str(e)) | ||||
|             self.assertIn('str', str(e)) | ||||
|         else: | ||||
|             self.fail('Passed invalid type but exception was not thrown') | ||||
|  | ||||
| @@ -99,9 +89,9 @@ class BoundStatementTestCase(unittest.TestCase): | ||||
|  | ||||
|         try: | ||||
|             bound_statement.bind(values) | ||||
|         except InvalidParameterTypeError as e: | ||||
|             self.assertEqual(e.col_name, 'foo2') | ||||
|             self.assertEqual(e.expected_type, Int32Type) | ||||
|             self.assertEqual(e.actual_type, list) | ||||
|         except TypeError as e: | ||||
|             self.assertIn('foo2', str(e)) | ||||
|             self.assertIn('Int32Type', str(e)) | ||||
|             self.assertIn('list', str(e)) | ||||
|         else: | ||||
|             self.fail('Passed invalid type but exception was not thrown') | ||||
|   | ||||
| @@ -5,6 +5,8 @@ except ImportError: | ||||
|  | ||||
| from itertools import islice, cycle | ||||
| from mock import Mock | ||||
| from random import randint | ||||
| import sys | ||||
| import struct | ||||
| from threading import Thread | ||||
|  | ||||
| @@ -88,11 +90,51 @@ class TestRoundRobinPolicy(unittest.TestCase): | ||||
|         map(lambda t: t.start(), threads) | ||||
|         map(lambda t: t.join(), threads) | ||||
|  | ||||
|     def test_thread_safety_during_modification(self): | ||||
|         hosts = range(100) | ||||
|         policy = RoundRobinPolicy() | ||||
|         policy.populate(None, hosts) | ||||
|  | ||||
|         errors = [] | ||||
|  | ||||
|         def check_query_plan(): | ||||
|             try: | ||||
|                 for i in xrange(100): | ||||
|                     list(policy.make_query_plan()) | ||||
|             except Exception as exc: | ||||
|                 errors.append(exc) | ||||
|  | ||||
|         def host_up(): | ||||
|             for i in xrange(1000): | ||||
|                 policy.on_up(randint(0, 99)) | ||||
|  | ||||
|         def host_down(): | ||||
|             for i in xrange(1000): | ||||
|                 policy.on_down(randint(0, 99)) | ||||
|  | ||||
|         threads = [] | ||||
|         for i in range(5): | ||||
|             threads.append(Thread(target=check_query_plan)) | ||||
|             threads.append(Thread(target=host_up)) | ||||
|             threads.append(Thread(target=host_down)) | ||||
|  | ||||
|         # make the GIL switch after every instruction, maximizing | ||||
|         # the chace of race conditions | ||||
|         original_interval = sys.getcheckinterval() | ||||
|         try: | ||||
|             sys.setcheckinterval(0) | ||||
|             map(lambda t: t.start(), threads) | ||||
|             map(lambda t: t.join(), threads) | ||||
|         finally: | ||||
|             sys.setcheckinterval(original_interval) | ||||
|  | ||||
|         if errors: | ||||
|             self.fail("Saw errors: %s" % (errors,)) | ||||
|  | ||||
|     def test_no_live_nodes(self): | ||||
|         """ | ||||
|         Ensure query plan for a downed cluster will execute without errors | ||||
|         """ | ||||
|  | ||||
|         hosts = [0, 1, 2, 3] | ||||
|         policy = RoundRobinPolicy() | ||||
|         policy.populate(None, hosts) | ||||
| @@ -408,6 +450,7 @@ class TokenAwarePolicyTest(unittest.TestCase): | ||||
|         qplan = list(policy.make_query_plan()) | ||||
|         self.assertEqual(qplan, []) | ||||
|  | ||||
|  | ||||
| class ConvictionPolicyTest(unittest.TestCase): | ||||
|     def test_not_implemented(self): | ||||
|         """ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Tyler Hobbs
					Tyler Hobbs