diff --git a/cassandra/cluster.py b/cassandra/cluster.py index f9a60142..7451dd3d 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1168,7 +1168,7 @@ class Session(object): if isinstance(query, six.string_types): query = SimpleStatement(query) elif isinstance(query, PreparedStatement): - query = query.bind(parameters, self._protocol_version) + query = query.bind(parameters) cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level fetch_size = query.fetch_size @@ -1246,7 +1246,8 @@ class Session(object): raise prepared_statement = PreparedStatement.from_message( - query_id, column_metadata, self.cluster.metadata, query, self.keyspace) + query_id, column_metadata, self.cluster.metadata, query, self.keyspace, + self._protocol_version) host = future._current_host try: diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 99979c32..c37478be 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -558,7 +558,7 @@ class ResultMessage(_MessageType): elif kind == RESULT_KIND_PREPARED: results = cls.recv_results_prepared(f) elif kind == RESULT_KIND_SCHEMA_CHANGE: - results = cls.recv_results_schema_change(f) + results = cls.recv_results_schema_change(f, protocol_version) return cls(kind, results, paging_state) @classmethod diff --git a/cassandra/query.py b/cassandra/query.py index ce7a92e5..fcd24d26 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -313,21 +313,25 @@ class PreparedStatement(object): consistency_level = None serial_consistency_level = None + _protocol_version = None + def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace, - consistency_level=None, serial_consistency_level=None, fetch_size=None): + protocol_version, consistency_level=None, serial_consistency_level=None, + fetch_size=None): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes self.query_string = query self.keyspace = keyspace + self._protocol_version = protocol_version self.consistency_level = consistency_level self.serial_consistency_level = serial_consistency_level self.fetch_size = fetch_size @classmethod - def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace): + def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace, protocol_version): if not column_metadata: - return PreparedStatement(column_metadata, query_id, None, query, keyspace) + return PreparedStatement(column_metadata, query_id, None, query, keyspace, protocol_version) partition_key_columns = None routing_key_indexes = None @@ -350,15 +354,16 @@ class PreparedStatement(object): pass # we're missing a partition key component in the prepared # statement; just leave routing_key_indexes as None - return PreparedStatement(column_metadata, query_id, routing_key_indexes, query, keyspace) + return PreparedStatement(column_metadata, query_id, routing_key_indexes, + query, keyspace, protocol_version) - def bind(self, values, protocol_version): + def bind(self, values): """ Creates and returns a :class:`BoundStatement` instance using `values`. The `values` parameter **must** be a sequence, such as a tuple or list, even if there is only one value to bind. """ - return BoundStatement(self).bind(values, protocol_version) + return BoundStatement(self).bind(values) def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') @@ -397,7 +402,7 @@ class BoundStatement(Statement): Statement.__init__(self, *args, **kwargs) - def bind(self, values, protocol_version): + def bind(self, values): """ Binds a sequence of values for the prepared statement parameters and returns this instance. Note that `values` *must* be: @@ -408,6 +413,8 @@ class BoundStatement(Statement): values = () col_meta = self.prepared_statement.column_metadata + proto_version = self.prepared_statement._protocol_version + # special case for binding dicts if isinstance(values, dict): dict_values = values @@ -457,7 +464,7 @@ class BoundStatement(Statement): col_type = col_spec[-1] try: - self.values.append(col_type.serialize(value, protocol_version)) + self.values.append(col_type.serialize(value, proto_version)) except (TypeError, struct.error): col_name = col_spec[2] expected_type = col_type diff --git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 79b0fde1..fca1e6fa 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -77,10 +77,12 @@ class ConnectionTest(object): else: conn.send_msg( QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE), + request_id=0, cb=partial(cb, count)) conn.send_msg( QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE), + request_id=0, cb=partial(cb, 0)) event.wait() @@ -102,6 +104,7 @@ class ConnectionTest(object): for i in range(100): conn.send_msg( QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE), + request_id=i, cb=partial(cb, responses, i)) event.wait() @@ -122,11 +125,13 @@ class ConnectionTest(object): else: conn.send_msg( QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE), + request_id=count, cb=partial(cb, event, conn, count)) for event, conn in zip(events, conns): conn.send_msg( QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE), + request_id=0, cb=partial(cb, event, conn, 0)) for event in events: @@ -153,7 +158,9 @@ class ConnectionTest(object): def send_msgs(all_responses, thread_responses): for i in range(num_requests_per_conn): qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - conn.send_msg(qmsg, cb=partial(cb, all_responses, thread_responses, i)) + with conn.lock: + request_id = conn.get_request_id() + conn.send_msg(qmsg, request_id, cb=partial(cb, all_responses, thread_responses, i)) all_responses = [] threads = [] @@ -192,7 +199,9 @@ class ConnectionTest(object): thread_responses = [False] * num_requests_per_conn for i in range(num_requests_per_conn): qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - conn.send_msg(qmsg, cb=partial(cb, conn, event, thread_responses, i)) + with conn.lock: + request_id = conn.get_request_id() + conn.send_msg(qmsg, request_id, cb=partial(cb, conn, event, thread_responses, i)) event.wait() diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index fad61e2e..3f79c6b5 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -83,13 +83,14 @@ class BoundStatementTestCase(unittest.TestCase): query_id=None, routing_key_indexes=[], query=None, - keyspace=keyspace) + keyspace=keyspace, + protocol_version=2) bound_statement = BoundStatement(prepared_statement=prepared_statement) values = ['nonint', 1] try: - bound_statement.bind(values, protocol_version=1) + bound_statement.bind(values) except TypeError as e: self.assertIn('foo1', str(e)) self.assertIn('Int32Type', str(e)) @@ -100,7 +101,7 @@ class BoundStatementTestCase(unittest.TestCase): values = [1, ['1', '2']] try: - bound_statement.bind(values, protocol_version=1) + bound_statement.bind(values) except TypeError as e: self.assertIn('foo2', str(e)) self.assertIn('Int32Type', str(e))