Don't require proto version to be passed to bind()
This commit is contained in:
@@ -1168,7 +1168,7 @@ class Session(object):
|
||||
if isinstance(query, six.string_types):
|
||||
query = SimpleStatement(query)
|
||||
elif isinstance(query, PreparedStatement):
|
||||
query = query.bind(parameters, self._protocol_version)
|
||||
query = query.bind(parameters)
|
||||
|
||||
cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level
|
||||
fetch_size = query.fetch_size
|
||||
@@ -1246,7 +1246,8 @@ class Session(object):
|
||||
raise
|
||||
|
||||
prepared_statement = PreparedStatement.from_message(
|
||||
query_id, column_metadata, self.cluster.metadata, query, self.keyspace)
|
||||
query_id, column_metadata, self.cluster.metadata, query, self.keyspace,
|
||||
self._protocol_version)
|
||||
|
||||
host = future._current_host
|
||||
try:
|
||||
|
||||
@@ -558,7 +558,7 @@ class ResultMessage(_MessageType):
|
||||
elif kind == RESULT_KIND_PREPARED:
|
||||
results = cls.recv_results_prepared(f)
|
||||
elif kind == RESULT_KIND_SCHEMA_CHANGE:
|
||||
results = cls.recv_results_schema_change(f)
|
||||
results = cls.recv_results_schema_change(f, protocol_version)
|
||||
return cls(kind, results, paging_state)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -313,21 +313,25 @@ class PreparedStatement(object):
|
||||
consistency_level = None
|
||||
serial_consistency_level = None
|
||||
|
||||
_protocol_version = None
|
||||
|
||||
def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace,
|
||||
consistency_level=None, serial_consistency_level=None, fetch_size=None):
|
||||
protocol_version, consistency_level=None, serial_consistency_level=None,
|
||||
fetch_size=None):
|
||||
self.column_metadata = column_metadata
|
||||
self.query_id = query_id
|
||||
self.routing_key_indexes = routing_key_indexes
|
||||
self.query_string = query
|
||||
self.keyspace = keyspace
|
||||
self._protocol_version = protocol_version
|
||||
self.consistency_level = consistency_level
|
||||
self.serial_consistency_level = serial_consistency_level
|
||||
self.fetch_size = fetch_size
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace):
|
||||
def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace, protocol_version):
|
||||
if not column_metadata:
|
||||
return PreparedStatement(column_metadata, query_id, None, query, keyspace)
|
||||
return PreparedStatement(column_metadata, query_id, None, query, keyspace, protocol_version)
|
||||
|
||||
partition_key_columns = None
|
||||
routing_key_indexes = None
|
||||
@@ -350,15 +354,16 @@ class PreparedStatement(object):
|
||||
pass # we're missing a partition key component in the prepared
|
||||
# statement; just leave routing_key_indexes as None
|
||||
|
||||
return PreparedStatement(column_metadata, query_id, routing_key_indexes, query, keyspace)
|
||||
return PreparedStatement(column_metadata, query_id, routing_key_indexes,
|
||||
query, keyspace, protocol_version)
|
||||
|
||||
def bind(self, values, protocol_version):
|
||||
def bind(self, values):
|
||||
"""
|
||||
Creates and returns a :class:`BoundStatement` instance using `values`.
|
||||
The `values` parameter **must** be a sequence, such as a tuple or list,
|
||||
even if there is only one value to bind.
|
||||
"""
|
||||
return BoundStatement(self).bind(values, protocol_version)
|
||||
return BoundStatement(self).bind(values)
|
||||
|
||||
def __str__(self):
|
||||
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
|
||||
@@ -397,7 +402,7 @@ class BoundStatement(Statement):
|
||||
|
||||
Statement.__init__(self, *args, **kwargs)
|
||||
|
||||
def bind(self, values, protocol_version):
|
||||
def bind(self, values):
|
||||
"""
|
||||
Binds a sequence of values for the prepared statement parameters
|
||||
and returns this instance. Note that `values` *must* be:
|
||||
@@ -408,6 +413,8 @@ class BoundStatement(Statement):
|
||||
values = ()
|
||||
col_meta = self.prepared_statement.column_metadata
|
||||
|
||||
proto_version = self.prepared_statement._protocol_version
|
||||
|
||||
# special case for binding dicts
|
||||
if isinstance(values, dict):
|
||||
dict_values = values
|
||||
@@ -457,7 +464,7 @@ class BoundStatement(Statement):
|
||||
col_type = col_spec[-1]
|
||||
|
||||
try:
|
||||
self.values.append(col_type.serialize(value, protocol_version))
|
||||
self.values.append(col_type.serialize(value, proto_version))
|
||||
except (TypeError, struct.error):
|
||||
col_name = col_spec[2]
|
||||
expected_type = col_type
|
||||
|
||||
@@ -77,10 +77,12 @@ class ConnectionTest(object):
|
||||
else:
|
||||
conn.send_msg(
|
||||
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
|
||||
request_id=0,
|
||||
cb=partial(cb, count))
|
||||
|
||||
conn.send_msg(
|
||||
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
|
||||
request_id=0,
|
||||
cb=partial(cb, 0))
|
||||
event.wait()
|
||||
|
||||
@@ -102,6 +104,7 @@ class ConnectionTest(object):
|
||||
for i in range(100):
|
||||
conn.send_msg(
|
||||
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
|
||||
request_id=i,
|
||||
cb=partial(cb, responses, i))
|
||||
|
||||
event.wait()
|
||||
@@ -122,11 +125,13 @@ class ConnectionTest(object):
|
||||
else:
|
||||
conn.send_msg(
|
||||
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
|
||||
request_id=count,
|
||||
cb=partial(cb, event, conn, count))
|
||||
|
||||
for event, conn in zip(events, conns):
|
||||
conn.send_msg(
|
||||
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
|
||||
request_id=0,
|
||||
cb=partial(cb, event, conn, 0))
|
||||
|
||||
for event in events:
|
||||
@@ -153,7 +158,9 @@ class ConnectionTest(object):
|
||||
def send_msgs(all_responses, thread_responses):
|
||||
for i in range(num_requests_per_conn):
|
||||
qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
|
||||
conn.send_msg(qmsg, cb=partial(cb, all_responses, thread_responses, i))
|
||||
with conn.lock:
|
||||
request_id = conn.get_request_id()
|
||||
conn.send_msg(qmsg, request_id, cb=partial(cb, all_responses, thread_responses, i))
|
||||
|
||||
all_responses = []
|
||||
threads = []
|
||||
@@ -192,7 +199,9 @@ class ConnectionTest(object):
|
||||
thread_responses = [False] * num_requests_per_conn
|
||||
for i in range(num_requests_per_conn):
|
||||
qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
|
||||
conn.send_msg(qmsg, cb=partial(cb, conn, event, thread_responses, i))
|
||||
with conn.lock:
|
||||
request_id = conn.get_request_id()
|
||||
conn.send_msg(qmsg, request_id, cb=partial(cb, conn, event, thread_responses, i))
|
||||
|
||||
event.wait()
|
||||
|
||||
|
||||
@@ -83,13 +83,14 @@ class BoundStatementTestCase(unittest.TestCase):
|
||||
query_id=None,
|
||||
routing_key_indexes=[],
|
||||
query=None,
|
||||
keyspace=keyspace)
|
||||
keyspace=keyspace,
|
||||
protocol_version=2)
|
||||
bound_statement = BoundStatement(prepared_statement=prepared_statement)
|
||||
|
||||
values = ['nonint', 1]
|
||||
|
||||
try:
|
||||
bound_statement.bind(values, protocol_version=1)
|
||||
bound_statement.bind(values)
|
||||
except TypeError as e:
|
||||
self.assertIn('foo1', str(e))
|
||||
self.assertIn('Int32Type', str(e))
|
||||
@@ -100,7 +101,7 @@ class BoundStatementTestCase(unittest.TestCase):
|
||||
values = [1, ['1', '2']]
|
||||
|
||||
try:
|
||||
bound_statement.bind(values, protocol_version=1)
|
||||
bound_statement.bind(values)
|
||||
except TypeError as e:
|
||||
self.assertIn('foo2', str(e))
|
||||
self.assertIn('Int32Type', str(e))
|
||||
|
||||
Reference in New Issue
Block a user