Merge pull request #321 from datastax/317

PYTHON-317 - Distinguish NULL and UNSET in bound parameters, protocol v4+
This commit is contained in:
Adam Holmberg
2015-05-26 16:29:08 -05:00
5 changed files with 280 additions and 115 deletions

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import # to enable import io from stdlib from __future__ import absolute_import # to enable import io from stdlib
from collections import namedtuple
import logging import logging
import socket import socket
from uuid import UUID from uuid import UUID
@@ -49,6 +50,7 @@ class NotSupportedError(Exception):
class InternalError(Exception): class InternalError(Exception):
pass pass
ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type'])
HEADER_DIRECTION_FROM_CLIENT = 0x00 HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80 HEADER_DIRECTION_TO_CLIENT = 0x80
@@ -62,6 +64,7 @@ WARNING_FLAG = 0x08
_message_types_by_name = {} _message_types_by_name = {}
_message_types_by_opcode = {} _message_types_by_opcode = {}
_UNSET_VALUE = object()
class _RegisterMessageType(type): class _RegisterMessageType(type):
def __init__(cls, name, bases, dct): def __init__(cls, name, bases, dct):
@@ -727,7 +730,7 @@ class ResultMessage(_MessageType):
colcfname = read_string(f) colcfname = read_string(f)
colname = read_string(f) colname = read_string(f)
coltype = cls.read_type(f, user_type_map) coltype = cls.read_type(f, user_type_map)
column_metadata.append((colksname, colcfname, colname, coltype)) column_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype))
return column_metadata, pk_indexes return column_metadata, pk_indexes
@classmethod @classmethod
@@ -1113,6 +1116,8 @@ def read_value(f):
def write_value(f, v): def write_value(f, v):
if v is None: if v is None:
write_int(f, -1) write_int(f, -1)
elif v is _UNSET_VALUE:
write_int(f, -2)
else: else:
write_int(f, len(v)) write_int(f, len(v))
f.write(v) f.write(v)

View File

@@ -24,16 +24,30 @@ import re
import struct import struct
import time import time
import six import six
from six.moves import range
from cassandra import ConsistencyLevel, OperationTimedOut from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.util import unix_time_from_uuid1 from cassandra.util import unix_time_from_uuid1
from cassandra.encoder import Encoder from cassandra.encoder import Encoder
import cassandra.encoder import cassandra.encoder
from cassandra.protocol import _UNSET_VALUE
from cassandra.util import OrderedDict from cassandra.util import OrderedDict
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
UNSET_VALUE = _UNSET_VALUE
"""
Specifies an unset value when binding a prepared statement.
Unset values are ignored, allowing prepared statements to be used without specify
See https://issues.apache.org/jira/browse/CASSANDRA-7304 for further details on semantics.
.. versionadded:: 2.6.0
Only valid when using native protocol v4+
"""
NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]')
START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*')
@@ -350,6 +364,7 @@ class PreparedStatement(object):
keyspace = None # change to prepared_keyspace in major release keyspace = None # change to prepared_keyspace in major release
routing_key_indexes = None routing_key_indexes = None
_routing_key_index_set = None
consistency_level = None consistency_level = None
serial_consistency_level = None serial_consistency_level = None
@@ -377,18 +392,17 @@ class PreparedStatement(object):
if pk_indexes: if pk_indexes:
routing_key_indexes = pk_indexes routing_key_indexes = pk_indexes
else: else:
partition_key_columns = None
routing_key_indexes = None routing_key_indexes = None
ks_name, table_name, _, _ = column_metadata[0] first_col = column_metadata[0]
ks_meta = cluster_metadata.keyspaces.get(ks_name) ks_meta = cluster_metadata.keyspaces.get(first_col.keyspace_name)
if ks_meta: if ks_meta:
table_meta = ks_meta.tables.get(table_name) table_meta = ks_meta.tables.get(first_col.table_name)
if table_meta: if table_meta:
partition_key_columns = table_meta.partition_key partition_key_columns = table_meta.partition_key
# make a map of {column_name: index} for each column in the statement # make a map of {column_name: index} for each column in the statement
statement_indexes = dict((c[2], i) for i, c in enumerate(column_metadata)) statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata))
# a list of which indexes in the statement correspond to partition key items # a list of which indexes in the statement correspond to partition key items
try: try:
@@ -403,11 +417,16 @@ class PreparedStatement(object):
def bind(self, values): def bind(self, values):
""" """
Creates and returns a :class:`BoundStatement` instance using `values`. Creates and returns a :class:`BoundStatement` instance using `values`.
The `values` parameter **must** be a sequence, such as a tuple or list,
even if there is only one value to bind. See :meth:`BoundStatement.bind` for rules on input ``values``.
""" """
return BoundStatement(self).bind(values) return BoundStatement(self).bind(values)
def is_routing_key_index(self, i):
if self._routing_key_index_set is None:
self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
return i in self._routing_key_index_set
def __str__(self): def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
return (u'<PreparedStatement query="%s", consistency=%s>' % return (u'<PreparedStatement query="%s", consistency=%s>' %
@@ -447,7 +466,7 @@ class BoundStatement(Statement):
meta = prepared_statement.column_metadata meta = prepared_statement.column_metadata
if meta: if meta:
self.keyspace = meta[0][0] self.keyspace = meta[0].keyspace_name
Statement.__init__(self, *args, **kwargs) Statement.__init__(self, *args, **kwargs)
@@ -455,82 +474,95 @@ class BoundStatement(Statement):
""" """
Binds a sequence of values for the prepared statement parameters Binds a sequence of values for the prepared statement parameters
and returns this instance. Note that `values` *must* be: and returns this instance. Note that `values` *must* be:
* a sequence, even if you are only binding one value, or * a sequence, even if you are only binding one value, or
* a dict that relates 1-to-1 between dict keys and columns * a dict that relates 1-to-1 between dict keys and columns
.. versionchanged:: 2.6.0
:data:`~.UNSET_VALUE` was introduced. These can be bound as positional parameters
in a sequence, or by name in a dict. Additionally, when using protocol v4+:
* short sequences will be extended to match bind parameters with UNSET_VALUE
* names may be omitted from a dict with UNSET_VALUE implied.
""" """
if values is None: if values is None:
values = () values = ()
col_meta = self.prepared_statement.column_metadata
proto_version = self.prepared_statement.protocol_version proto_version = self.prepared_statement.protocol_version
col_meta = self.prepared_statement.column_metadata
col_meta_len = len(col_meta)
value_len = len(values)
# special case for binding dicts # special case for binding dicts
if isinstance(values, dict): if isinstance(values, dict):
dict_values = values unbound_values = values.copy()
values = [] values = []
# sort values accordingly # sort values accordingly
for col in col_meta: for col in col_meta:
try: try:
values.append(dict_values[col[2]]) values.append(unbound_values.pop(col.name))
except KeyError: except KeyError:
raise KeyError( if proto_version >= 4:
'Column name `%s` not found in bound dict.' % values.append(UNSET_VALUE)
(col[2])) else:
raise KeyError(
'Column name `%s` not found in bound dict.' %
(col.name))
# ensure a 1-to-1 dict keys to columns relationship value_len = len(values)
if len(dict_values) != len(col_meta):
# find expected columns
columns = set()
for col in col_meta:
columns.add(col[2])
# generate error message if unbound_values:
if len(dict_values) > len(col_meta): raise ValueError("Unexpected arguments provided to bind(): %s" % unbound_values.keys())
difference = set(dict_values.keys()).difference(columns)
msg = "Too many arguments provided to bind() (got %d, expected %d). " + \
"Unexpected keys %s."
else:
difference = set(columns).difference(dict_values.keys())
msg = "Too few arguments provided to bind() (got %d, expected %d). " + \
"Expected keys %s."
# exit with error message if value_len > col_meta_len:
msg = msg % (len(values), len(col_meta), difference)
raise ValueError(msg)
if len(values) > len(col_meta):
raise ValueError( raise ValueError(
"Too many arguments provided to bind() (got %d, expected %d)" % "Too many arguments provided to bind() (got %d, expected %d)" %
(len(values), len(col_meta))) (len(values), len(col_meta)))
if self.prepared_statement.routing_key_indexes and \ # this is fail-fast for clarity pre-v4. When v4 can be assumed,
len(values) < len(self.prepared_statement.routing_key_indexes): # the error will be better reported when UNSET_VALUE is implicitly added.
if proto_version < 4 and self.prepared_statement.routing_key_indexes and \
value_len < len(self.prepared_statement.routing_key_indexes):
raise ValueError( raise ValueError(
"Too few arguments provided to bind() (got %d, required %d for routing key)" % "Too few arguments provided to bind() (got %d, required %d for routing key)" %
(len(values), len(self.prepared_statement.routing_key_indexes))) (value_len, len(self.prepared_statement.routing_key_indexes)))
self.raw_values = values self.raw_values = values
self.values = [] self.values = []
for value, col_spec in zip(values, col_meta): for value, col_spec in zip(values, col_meta):
if value is None: if value is None:
self.values.append(None) self.values.append(None)
elif value is UNSET_VALUE:
if proto_version >= 4:
self._append_unset_value()
else:
raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version)
else: else:
col_type = col_spec[-1]
try: try:
self.values.append(col_type.serialize(value, proto_version)) self.values.append(col_spec.type.serialize(value, proto_version))
except (TypeError, struct.error) as exc: except (TypeError, struct.error) as exc:
col_name = col_spec[2]
expected_type = col_type
actual_type = type(value) actual_type = type(value)
message = ('Received an argument of invalid type for column "%s". ' message = ('Received an argument of invalid type for column "%s". '
'Expected: %s, Got: %s; (%s)' % (col_name, expected_type, actual_type, exc)) 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc))
raise TypeError(message) raise TypeError(message)
if proto_version >= 4:
diff = col_meta_len - len(self.values)
if diff:
for _ in range(diff):
self._append_unset_value()
return self return self
def _append_unset_value(self):
next_index = len(self.values)
if self.prepared_statement.is_routing_key_index(next_index):
col_meta = self.prepared_statement.column_metadata[next_index]
raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name)
self.values.append(UNSET_VALUE)
@property @property
def routing_key(self): def routing_key(self):
if not self.prepared_statement.routing_key_indexes: if not self.prepared_statement.routing_key_indexes:

View File

@@ -23,6 +23,9 @@
.. autoclass:: BoundStatement .. autoclass:: BoundStatement
:members: :members:
.. autodata:: UNSET_VALUE
:annotation:
.. autoclass:: BatchStatement (batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None) .. autoclass:: BatchStatement (batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None)
:members: :members:

View File

@@ -22,12 +22,13 @@ except ImportError:
from cassandra import InvalidRequest from cassandra import InvalidRequest
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.query import PreparedStatement from cassandra.query import PreparedStatement, UNSET_VALUE
def setup_module(): def setup_module():
use_singledc() use_singledc()
class PreparedStatementTests(unittest.TestCase): class PreparedStatementTests(unittest.TestCase):
def test_basic(self): def test_basic(self):
@@ -65,9 +66,9 @@ class PreparedStatementTests(unittest.TestCase):
session.execute(bound) session.execute(bound)
prepared = session.prepare( prepared = session.prepare(
""" """
SELECT * FROM cf0 WHERE a=? SELECT * FROM cf0 WHERE a=?
""") """)
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind(('a')) bound = prepared.bind(('a'))
@@ -90,9 +91,9 @@ class PreparedStatementTests(unittest.TestCase):
session.execute(bound) session.execute(bound)
prepared = session.prepare( prepared = session.prepare(
""" """
SELECT * FROM cf0 WHERE a=? SELECT * FROM cf0 WHERE a=?
""") """)
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
@@ -163,7 +164,7 @@ class PreparedStatementTests(unittest.TestCase):
def test_too_many_bind_values_dicts(self): def test_too_many_bind_values_dicts(self):
""" """
Ensure a ValueError is thrown when attempting to bind too many variables Ensure an error is thrown when attempting to bind the wrong values
with dict bindings with dict bindings
""" """
@@ -172,15 +173,29 @@ class PreparedStatementTests(unittest.TestCase):
prepared = session.prepare( prepared = session.prepare(
""" """
INSERT INTO test3rf.test (v) VALUES (?) INSERT INTO test3rf.test (k, v) VALUES (?, ?)
""") """)
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2})
# too many values
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2, 'v2': 3})
# right number, but one does not belong
if PROTOCOL_VERSION < 4:
# pre v4, the driver bails with key error when 'v' is found missing
self.assertRaises(KeyError, prepared.bind, {'k': 1, 'v2': 3})
else:
# post v4, the driver uses UNSET_VALUE for 'v' and bails when 'v2' is unbound
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v2': 3})
# also catch too few variables with dicts # also catch too few variables with dicts
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
self.assertRaises(KeyError, prepared.bind, {}) if PROTOCOL_VERSION < 4:
self.assertRaises(KeyError, prepared.bind, {})
else:
# post v4, the driver attempts to use UNSET_VALUE for unspecified keys
self.assertRaises(ValueError, prepared.bind, {})
cluster.shutdown() cluster.shutdown()
@@ -202,9 +217,9 @@ class PreparedStatementTests(unittest.TestCase):
session.execute(bound) session.execute(bound)
prepared = session.prepare( prepared = session.prepare(
""" """
SELECT * FROM test3rf.test WHERE k=? SELECT * FROM test3rf.test WHERE k=?
""") """)
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind((1,)) bound = prepared.bind((1,))
@@ -213,6 +228,56 @@ class PreparedStatementTests(unittest.TestCase):
cluster.shutdown() cluster.shutdown()
def test_unset_values(self):
"""
Test to validate that UNSET_VALUEs are bound, and have the expected effect
Prepare a statement and insert all values. Then follow with execute excluding
parameters. Verify that the original values are unaffected.
@since 2.6.0
@jira_ticket PYTHON-317
@expected_result UNSET_VALUE is implicitly added to bind parameters, and properly encoded, leving unset values unaffected.
@test_category prepared_statements:binding
"""
if PROTOCOL_VERSION < 4:
raise unittest.SkipTest("Binding UNSET values is not supported in protocol version < 4")
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
# table with at least two values so one can be used as a marker
session.execute("CREATE TABLE IF NOT EXISTS test1rf.test_unset_values (k int PRIMARY KEY, v0 int, v1 int)")
insert = session.prepare("INSERT INTO test1rf.test_unset_values (k, v0, v1) VALUES (?, ?, ?)")
select = session.prepare("SELECT * FROM test1rf.test_unset_values WHERE k=?")
bind_expected = [
# initial condition
((0, 0, 0), (0, 0, 0)),
# unset implicit
((0, 1,), (0, 1, 0)),
({'k': 0, 'v0': 2}, (0, 2, 0)),
({'k': 0, 'v1': 1}, (0, 2, 1)),
# unset explicit
((0, 3, UNSET_VALUE), (0, 3, 1)),
((0, UNSET_VALUE, 2), (0, 3, 2)),
({'k': 0, 'v0': 4, 'v1': UNSET_VALUE}, (0, 4, 2)),
({'k': 0, 'v0': UNSET_VALUE, 'v1': 3}, (0, 4, 3)),
# nulls still work
((0, None, None), (0, None, None)),
]
for params, expected in bind_expected:
session.execute(insert, params)
results = session.execute(select, (0,))
self.assertEqual(results[0], expected)
self.assertRaises(ValueError, session.execute, select, (UNSET_VALUE, 0, 0))
cluster.shutdown()
def test_no_meta(self): def test_no_meta(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION) cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
@@ -227,9 +292,9 @@ class PreparedStatementTests(unittest.TestCase):
session.execute(bound) session.execute(bound)
prepared = session.prepare( prepared = session.prepare(
""" """
SELECT * FROM test3rf.test WHERE k=0 SELECT * FROM test3rf.test WHERE k=0
""") """)
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind(None) bound = prepared.bind(None)
@@ -257,9 +322,9 @@ class PreparedStatementTests(unittest.TestCase):
session.execute(bound) session.execute(bound)
prepared = session.prepare( prepared = session.prepare(
""" """
SELECT * FROM test3rf.test WHERE k=? SELECT * FROM test3rf.test WHERE k=?
""") """)
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind({'k': 1}) bound = prepared.bind({'k': 1})
@@ -286,9 +351,9 @@ class PreparedStatementTests(unittest.TestCase):
future.result() future.result()
prepared = session.prepare( prepared = session.prepare(
""" """
SELECT * FROM test3rf.test WHERE k=? SELECT * FROM test3rf.test WHERE k=?
""") """)
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
future = session.execute_async(prepared, (873,)) future = session.execute_async(prepared, (873,))
@@ -315,9 +380,9 @@ class PreparedStatementTests(unittest.TestCase):
future.result() future.result()
prepared = session.prepare( prepared = session.prepare(
""" """
SELECT * FROM test3rf.test WHERE k=? SELECT * FROM test3rf.test WHERE k=?
""") """)
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
future = session.execute_async(prepared, {'k': 873}) future = session.execute_async(prepared, {'k': 873})

View File

@@ -18,8 +18,9 @@ except ImportError:
import unittest # noqa import unittest # noqa
from cassandra.encoder import Encoder from cassandra.encoder import Encoder
from cassandra.query import bind_params, ValueSequence from cassandra.protocol import ColumnMetadata
from cassandra.query import PreparedStatement, BoundStatement from cassandra.query import (bind_params, ValueSequence, PreparedStatement,
BoundStatement, UNSET_VALUE)
from cassandra.cqltypes import Int32Type from cassandra.cqltypes import Int32Type
from cassandra.util import OrderedDict from cassandra.util import OrderedDict
@@ -73,42 +74,42 @@ class ParamBindingTest(unittest.TestCase):
self.assertEqual(float(bind_params("%s", (f,), Encoder())), f) self.assertEqual(float(bind_params("%s", (f,), Encoder())), f)
class BoundStatementTestCase(unittest.TestCase): class BoundStatementTestV1(unittest.TestCase):
protocol_version=1
@classmethod
def setUpClass(cls):
cls.prepared = PreparedStatement(column_metadata=[
ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type),
ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type),
ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type),
ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)
],
query_id=None,
routing_key_indexes=[1, 0],
query=None,
keyspace='keyspace',
protocol_version=cls.protocol_version)
cls.bound = BoundStatement(prepared_statement=cls.prepared)
def test_invalid_argument_type(self): def test_invalid_argument_type(self):
keyspace = 'keyspace1' values = (0, 0, 0, 'string not int')
column_family = 'cf1'
column_metadata = [
(keyspace, column_family, 'foo1', Int32Type),
(keyspace, column_family, 'foo2', Int32Type)
]
prepared_statement = PreparedStatement(column_metadata=column_metadata,
query_id=None,
routing_key_indexes=[],
query=None,
keyspace=keyspace,
protocol_version=2)
bound_statement = BoundStatement(prepared_statement=prepared_statement)
values = ['nonint', 1]
try: try:
bound_statement.bind(values) self.bound.bind(values)
except TypeError as e: except TypeError as e:
self.assertIn('foo1', str(e)) self.assertIn('v0', str(e))
self.assertIn('Int32Type', str(e)) self.assertIn('Int32Type', str(e))
self.assertIn('str', str(e)) self.assertIn('str', str(e))
else: else:
self.fail('Passed invalid type but exception was not thrown') self.fail('Passed invalid type but exception was not thrown')
values = [1, ['1', '2']] values = (['1', '2'], 0, 0, 0)
try: try:
bound_statement.bind(values) self.bound.bind(values)
except TypeError as e: except TypeError as e:
self.assertIn('foo2', str(e)) self.assertIn('rk0', str(e))
self.assertIn('Int32Type', str(e)) self.assertIn('Int32Type', str(e))
self.assertIn('list', str(e)) self.assertIn('list', str(e))
else: else:
@@ -119,8 +120,8 @@ class BoundStatementTestCase(unittest.TestCase):
column_family = 'cf1' column_family = 'cf1'
column_metadata = [ column_metadata = [
(keyspace, column_family, 'foo1', Int32Type), ColumnMetadata(keyspace, column_family, 'foo1', Int32Type),
(keyspace, column_family, 'foo2', Int32Type) ColumnMetadata(keyspace, column_family, 'foo2', Int32Type)
] ]
prepared_statement = PreparedStatement(column_metadata=column_metadata, prepared_statement = PreparedStatement(column_metadata=column_metadata,
@@ -128,28 +129,87 @@ class BoundStatementTestCase(unittest.TestCase):
routing_key_indexes=[], routing_key_indexes=[],
query=None, query=None,
keyspace=keyspace, keyspace=keyspace,
protocol_version=2) protocol_version=self.protocol_version)
prepared_statement.fetch_size = 1234 prepared_statement.fetch_size = 1234
bound_statement = BoundStatement(prepared_statement=prepared_statement) bound_statement = BoundStatement(prepared_statement=prepared_statement)
self.assertEqual(1234, bound_statement.fetch_size) self.assertEqual(1234, bound_statement.fetch_size)
def test_too_few_parameters_for_key(self): def test_too_few_parameters_for_routing_key(self):
keyspace = 'keyspace1' self.assertRaises(ValueError, self.prepared.bind, (1,))
column_family = 'cf1'
column_metadata = [ bound = self.prepared.bind((1, 2))
(keyspace, column_family, 'foo1', Int32Type), self.assertEqual(bound.keyspace, 'keyspace')
(keyspace, column_family, 'foo2', Int32Type)
]
prepared_statement = PreparedStatement(column_metadata=column_metadata, def test_dict_missing_routing_key(self):
self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0})
self.assertRaises(KeyError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0})
def test_missing_value(self):
self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0})
def test_dict_extra_value(self):
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123})
self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, 0, 123))
def test_values_none(self):
# should have values
self.assertRaises(ValueError, self.bound.bind, None)
# prepared statement with no values
prepared_statement = PreparedStatement(column_metadata=[],
query_id=None, query_id=None,
routing_key_indexes=[0, 1], routing_key_indexes=[],
query=None, query=None,
keyspace=keyspace, keyspace='whatever',
protocol_version=2) protocol_version=self.protocol_version)
bound = prepared_statement.bind(None)
self.assertListEqual(bound.values, [])
self.assertRaises(ValueError, prepared_statement.bind, (1,)) def test_bind_none(self):
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None})
self.assertEqual(self.bound.values[-1], None)
bound = prepared_statement.bind((1, 2)) old_values = self.bound.values
self.assertEqual(bound.keyspace, keyspace) self.bound.bind((0, 0, 0, None))
self.assertIsNot(self.bound.values, old_values)
self.assertEqual(self.bound.values[-1], None)
def test_unset_value(self):
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, UNSET_VALUE))
class BoundStatementTestV2(BoundStatementTestV1):
protocol_version=2
class BoundStatementTestV3(BoundStatementTestV1):
protocol_version=3
class BoundStatementTestV4(BoundStatementTestV1):
protocol_version=4
def test_dict_missing_routing_key(self):
# in v4 it implicitly binds UNSET_VALUE for missing items,
# UNSET_VALUE is ValueError for routing keys
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0})
self.assertRaises(ValueError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0})
def test_missing_value(self):
# in v4 missing values are UNSET_VALUE
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0})
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
old_values = self.bound.values
self.bound.bind((0, 0, 0))
self.assertIsNot(self.bound.values, old_values)
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
def test_unset_value(self):
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
old_values = self.bound.values
self.bound.bind((0, 0, 0, UNSET_VALUE))
self.assertEqual(self.bound.values[-1], UNSET_VALUE)