Don't require proto version to be passed to bind()

This commit is contained in:
Tyler Hobbs
2014-05-30 16:18:37 -05:00
parent 98ef13169a
commit 375283feb8
5 changed files with 34 additions and 16 deletions

View File

@@ -1168,7 +1168,7 @@ class Session(object):
if isinstance(query, six.string_types):
query = SimpleStatement(query)
elif isinstance(query, PreparedStatement):
query = query.bind(parameters, self._protocol_version)
query = query.bind(parameters)
cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level
fetch_size = query.fetch_size
@@ -1246,7 +1246,8 @@ class Session(object):
raise
prepared_statement = PreparedStatement.from_message(
query_id, column_metadata, self.cluster.metadata, query, self.keyspace)
query_id, column_metadata, self.cluster.metadata, query, self.keyspace,
self._protocol_version)
host = future._current_host
try:

View File

@@ -558,7 +558,7 @@ class ResultMessage(_MessageType):
elif kind == RESULT_KIND_PREPARED:
results = cls.recv_results_prepared(f)
elif kind == RESULT_KIND_SCHEMA_CHANGE:
results = cls.recv_results_schema_change(f)
results = cls.recv_results_schema_change(f, protocol_version)
return cls(kind, results, paging_state)
@classmethod

View File

@@ -313,21 +313,25 @@ class PreparedStatement(object):
consistency_level = None
serial_consistency_level = None
_protocol_version = None
def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace,
consistency_level=None, serial_consistency_level=None, fetch_size=None):
protocol_version, consistency_level=None, serial_consistency_level=None,
fetch_size=None):
self.column_metadata = column_metadata
self.query_id = query_id
self.routing_key_indexes = routing_key_indexes
self.query_string = query
self.keyspace = keyspace
self._protocol_version = protocol_version
self.consistency_level = consistency_level
self.serial_consistency_level = serial_consistency_level
self.fetch_size = fetch_size
@classmethod
def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace):
def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace, protocol_version):
if not column_metadata:
return PreparedStatement(column_metadata, query_id, None, query, keyspace)
return PreparedStatement(column_metadata, query_id, None, query, keyspace, protocol_version)
partition_key_columns = None
routing_key_indexes = None
@@ -350,15 +354,16 @@ class PreparedStatement(object):
pass # we're missing a partition key component in the prepared
# statement; just leave routing_key_indexes as None
return PreparedStatement(column_metadata, query_id, routing_key_indexes, query, keyspace)
return PreparedStatement(column_metadata, query_id, routing_key_indexes,
query, keyspace, protocol_version)
def bind(self, values, protocol_version):
def bind(self, values):
"""
Creates and returns a :class:`BoundStatement` instance using `values`.
The `values` parameter **must** be a sequence, such as a tuple or list,
even if there is only one value to bind.
"""
return BoundStatement(self).bind(values, protocol_version)
return BoundStatement(self).bind(values)
def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
@@ -397,7 +402,7 @@ class BoundStatement(Statement):
Statement.__init__(self, *args, **kwargs)
def bind(self, values, protocol_version):
def bind(self, values):
"""
Binds a sequence of values for the prepared statement parameters
and returns this instance. Note that `values` *must* be:
@@ -408,6 +413,8 @@ class BoundStatement(Statement):
values = ()
col_meta = self.prepared_statement.column_metadata
proto_version = self.prepared_statement._protocol_version
# special case for binding dicts
if isinstance(values, dict):
dict_values = values
@@ -457,7 +464,7 @@ class BoundStatement(Statement):
col_type = col_spec[-1]
try:
self.values.append(col_type.serialize(value, protocol_version))
self.values.append(col_type.serialize(value, proto_version))
except (TypeError, struct.error):
col_name = col_spec[2]
expected_type = col_type

View File

@@ -77,10 +77,12 @@ class ConnectionTest(object):
else:
conn.send_msg(
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
request_id=0,
cb=partial(cb, count))
conn.send_msg(
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
request_id=0,
cb=partial(cb, 0))
event.wait()
@@ -102,6 +104,7 @@ class ConnectionTest(object):
for i in range(100):
conn.send_msg(
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
request_id=i,
cb=partial(cb, responses, i))
event.wait()
@@ -122,11 +125,13 @@ class ConnectionTest(object):
else:
conn.send_msg(
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
request_id=count,
cb=partial(cb, event, conn, count))
for event, conn in zip(events, conns):
conn.send_msg(
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
request_id=0,
cb=partial(cb, event, conn, 0))
for event in events:
@@ -153,7 +158,9 @@ class ConnectionTest(object):
def send_msgs(all_responses, thread_responses):
for i in range(num_requests_per_conn):
qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
conn.send_msg(qmsg, cb=partial(cb, all_responses, thread_responses, i))
with conn.lock:
request_id = conn.get_request_id()
conn.send_msg(qmsg, request_id, cb=partial(cb, all_responses, thread_responses, i))
all_responses = []
threads = []
@@ -192,7 +199,9 @@ class ConnectionTest(object):
thread_responses = [False] * num_requests_per_conn
for i in range(num_requests_per_conn):
qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
conn.send_msg(qmsg, cb=partial(cb, conn, event, thread_responses, i))
with conn.lock:
request_id = conn.get_request_id()
conn.send_msg(qmsg, request_id, cb=partial(cb, conn, event, thread_responses, i))
event.wait()

View File

@@ -83,13 +83,14 @@ class BoundStatementTestCase(unittest.TestCase):
query_id=None,
routing_key_indexes=[],
query=None,
keyspace=keyspace)
keyspace=keyspace,
protocol_version=2)
bound_statement = BoundStatement(prepared_statement=prepared_statement)
values = ['nonint', 1]
try:
bound_statement.bind(values, protocol_version=1)
bound_statement.bind(values)
except TypeError as e:
self.assertIn('foo1', str(e))
self.assertIn('Int32Type', str(e))
@@ -100,7 +101,7 @@ class BoundStatementTestCase(unittest.TestCase):
values = [1, ['1', '2']]
try:
bound_statement.bind(values, protocol_version=1)
bound_statement.bind(values)
except TypeError as e:
self.assertIn('foo2', str(e))
self.assertIn('Int32Type', str(e))