Merge pull request #321 from datastax/317
PYTHON-317 - Distinguish NULL and UNSET in bound parameters, protocol v4+
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import # to enable import io from stdlib
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
import socket
|
||||
from uuid import UUID
|
||||
@@ -49,6 +50,7 @@ class NotSupportedError(Exception):
|
||||
class InternalError(Exception):
|
||||
pass
|
||||
|
||||
ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type'])
|
||||
|
||||
HEADER_DIRECTION_FROM_CLIENT = 0x00
|
||||
HEADER_DIRECTION_TO_CLIENT = 0x80
|
||||
@@ -62,6 +64,7 @@ WARNING_FLAG = 0x08
|
||||
_message_types_by_name = {}
|
||||
_message_types_by_opcode = {}
|
||||
|
||||
_UNSET_VALUE = object()
|
||||
|
||||
class _RegisterMessageType(type):
|
||||
def __init__(cls, name, bases, dct):
|
||||
@@ -727,7 +730,7 @@ class ResultMessage(_MessageType):
|
||||
colcfname = read_string(f)
|
||||
colname = read_string(f)
|
||||
coltype = cls.read_type(f, user_type_map)
|
||||
column_metadata.append((colksname, colcfname, colname, coltype))
|
||||
column_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype))
|
||||
return column_metadata, pk_indexes
|
||||
|
||||
@classmethod
|
||||
@@ -1113,6 +1116,8 @@ def read_value(f):
|
||||
def write_value(f, v):
|
||||
if v is None:
|
||||
write_int(f, -1)
|
||||
elif v is _UNSET_VALUE:
|
||||
write_int(f, -2)
|
||||
else:
|
||||
write_int(f, len(v))
|
||||
f.write(v)
|
||||
|
||||
@@ -24,16 +24,30 @@ import re
|
||||
import struct
|
||||
import time
|
||||
import six
|
||||
from six.moves import range
|
||||
|
||||
from cassandra import ConsistencyLevel, OperationTimedOut
|
||||
from cassandra.util import unix_time_from_uuid1
|
||||
from cassandra.encoder import Encoder
|
||||
import cassandra.encoder
|
||||
from cassandra.protocol import _UNSET_VALUE
|
||||
from cassandra.util import OrderedDict
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
UNSET_VALUE = _UNSET_VALUE
|
||||
"""
|
||||
Specifies an unset value when binding a prepared statement.
|
||||
|
||||
Unset values are ignored, allowing prepared statements to be used without specify
|
||||
|
||||
See https://issues.apache.org/jira/browse/CASSANDRA-7304 for further details on semantics.
|
||||
|
||||
.. versionadded:: 2.6.0
|
||||
|
||||
Only valid when using native protocol v4+
|
||||
"""
|
||||
|
||||
NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]')
|
||||
START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*')
|
||||
@@ -350,6 +364,7 @@ class PreparedStatement(object):
|
||||
keyspace = None # change to prepared_keyspace in major release
|
||||
|
||||
routing_key_indexes = None
|
||||
_routing_key_index_set = None
|
||||
|
||||
consistency_level = None
|
||||
serial_consistency_level = None
|
||||
@@ -377,18 +392,17 @@ class PreparedStatement(object):
|
||||
if pk_indexes:
|
||||
routing_key_indexes = pk_indexes
|
||||
else:
|
||||
partition_key_columns = None
|
||||
routing_key_indexes = None
|
||||
|
||||
ks_name, table_name, _, _ = column_metadata[0]
|
||||
ks_meta = cluster_metadata.keyspaces.get(ks_name)
|
||||
first_col = column_metadata[0]
|
||||
ks_meta = cluster_metadata.keyspaces.get(first_col.keyspace_name)
|
||||
if ks_meta:
|
||||
table_meta = ks_meta.tables.get(table_name)
|
||||
table_meta = ks_meta.tables.get(first_col.table_name)
|
||||
if table_meta:
|
||||
partition_key_columns = table_meta.partition_key
|
||||
|
||||
# make a map of {column_name: index} for each column in the statement
|
||||
statement_indexes = dict((c[2], i) for i, c in enumerate(column_metadata))
|
||||
statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata))
|
||||
|
||||
# a list of which indexes in the statement correspond to partition key items
|
||||
try:
|
||||
@@ -403,11 +417,16 @@ class PreparedStatement(object):
|
||||
def bind(self, values):
|
||||
"""
|
||||
Creates and returns a :class:`BoundStatement` instance using `values`.
|
||||
The `values` parameter **must** be a sequence, such as a tuple or list,
|
||||
even if there is only one value to bind.
|
||||
|
||||
See :meth:`BoundStatement.bind` for rules on input ``values``.
|
||||
"""
|
||||
return BoundStatement(self).bind(values)
|
||||
|
||||
def is_routing_key_index(self, i):
|
||||
if self._routing_key_index_set is None:
|
||||
self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
|
||||
return i in self._routing_key_index_set
|
||||
|
||||
def __str__(self):
|
||||
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
|
||||
return (u'<PreparedStatement query="%s", consistency=%s>' %
|
||||
@@ -447,7 +466,7 @@ class BoundStatement(Statement):
|
||||
|
||||
meta = prepared_statement.column_metadata
|
||||
if meta:
|
||||
self.keyspace = meta[0][0]
|
||||
self.keyspace = meta[0].keyspace_name
|
||||
|
||||
Statement.__init__(self, *args, **kwargs)
|
||||
|
||||
@@ -455,82 +474,95 @@ class BoundStatement(Statement):
|
||||
"""
|
||||
Binds a sequence of values for the prepared statement parameters
|
||||
and returns this instance. Note that `values` *must* be:
|
||||
|
||||
* a sequence, even if you are only binding one value, or
|
||||
* a dict that relates 1-to-1 between dict keys and columns
|
||||
|
||||
.. versionchanged:: 2.6.0
|
||||
|
||||
:data:`~.UNSET_VALUE` was introduced. These can be bound as positional parameters
|
||||
in a sequence, or by name in a dict. Additionally, when using protocol v4+:
|
||||
|
||||
* short sequences will be extended to match bind parameters with UNSET_VALUE
|
||||
* names may be omitted from a dict with UNSET_VALUE implied.
|
||||
|
||||
"""
|
||||
if values is None:
|
||||
values = ()
|
||||
col_meta = self.prepared_statement.column_metadata
|
||||
|
||||
proto_version = self.prepared_statement.protocol_version
|
||||
col_meta = self.prepared_statement.column_metadata
|
||||
col_meta_len = len(col_meta)
|
||||
value_len = len(values)
|
||||
|
||||
# special case for binding dicts
|
||||
if isinstance(values, dict):
|
||||
dict_values = values
|
||||
unbound_values = values.copy()
|
||||
values = []
|
||||
|
||||
# sort values accordingly
|
||||
for col in col_meta:
|
||||
try:
|
||||
values.append(dict_values[col[2]])
|
||||
values.append(unbound_values.pop(col.name))
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
'Column name `%s` not found in bound dict.' %
|
||||
(col[2]))
|
||||
if proto_version >= 4:
|
||||
values.append(UNSET_VALUE)
|
||||
else:
|
||||
raise KeyError(
|
||||
'Column name `%s` not found in bound dict.' %
|
||||
(col.name))
|
||||
|
||||
# ensure a 1-to-1 dict keys to columns relationship
|
||||
if len(dict_values) != len(col_meta):
|
||||
# find expected columns
|
||||
columns = set()
|
||||
for col in col_meta:
|
||||
columns.add(col[2])
|
||||
value_len = len(values)
|
||||
|
||||
# generate error message
|
||||
if len(dict_values) > len(col_meta):
|
||||
difference = set(dict_values.keys()).difference(columns)
|
||||
msg = "Too many arguments provided to bind() (got %d, expected %d). " + \
|
||||
"Unexpected keys %s."
|
||||
else:
|
||||
difference = set(columns).difference(dict_values.keys())
|
||||
msg = "Too few arguments provided to bind() (got %d, expected %d). " + \
|
||||
"Expected keys %s."
|
||||
if unbound_values:
|
||||
raise ValueError("Unexpected arguments provided to bind(): %s" % unbound_values.keys())
|
||||
|
||||
# exit with error message
|
||||
msg = msg % (len(values), len(col_meta), difference)
|
||||
raise ValueError(msg)
|
||||
|
||||
if len(values) > len(col_meta):
|
||||
if value_len > col_meta_len:
|
||||
raise ValueError(
|
||||
"Too many arguments provided to bind() (got %d, expected %d)" %
|
||||
(len(values), len(col_meta)))
|
||||
|
||||
if self.prepared_statement.routing_key_indexes and \
|
||||
len(values) < len(self.prepared_statement.routing_key_indexes):
|
||||
# this is fail-fast for clarity pre-v4. When v4 can be assumed,
|
||||
# the error will be better reported when UNSET_VALUE is implicitly added.
|
||||
if proto_version < 4 and self.prepared_statement.routing_key_indexes and \
|
||||
value_len < len(self.prepared_statement.routing_key_indexes):
|
||||
raise ValueError(
|
||||
"Too few arguments provided to bind() (got %d, required %d for routing key)" %
|
||||
(len(values), len(self.prepared_statement.routing_key_indexes)))
|
||||
(value_len, len(self.prepared_statement.routing_key_indexes)))
|
||||
|
||||
self.raw_values = values
|
||||
self.values = []
|
||||
for value, col_spec in zip(values, col_meta):
|
||||
if value is None:
|
||||
self.values.append(None)
|
||||
elif value is UNSET_VALUE:
|
||||
if proto_version >= 4:
|
||||
self._append_unset_value()
|
||||
else:
|
||||
raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version)
|
||||
else:
|
||||
col_type = col_spec[-1]
|
||||
|
||||
try:
|
||||
self.values.append(col_type.serialize(value, proto_version))
|
||||
self.values.append(col_spec.type.serialize(value, proto_version))
|
||||
except (TypeError, struct.error) as exc:
|
||||
col_name = col_spec[2]
|
||||
expected_type = col_type
|
||||
actual_type = type(value)
|
||||
|
||||
message = ('Received an argument of invalid type for column "%s". '
|
||||
'Expected: %s, Got: %s; (%s)' % (col_name, expected_type, actual_type, exc))
|
||||
'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc))
|
||||
raise TypeError(message)
|
||||
|
||||
if proto_version >= 4:
|
||||
diff = col_meta_len - len(self.values)
|
||||
if diff:
|
||||
for _ in range(diff):
|
||||
self._append_unset_value()
|
||||
|
||||
return self
|
||||
|
||||
def _append_unset_value(self):
|
||||
next_index = len(self.values)
|
||||
if self.prepared_statement.is_routing_key_index(next_index):
|
||||
col_meta = self.prepared_statement.column_metadata[next_index]
|
||||
raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name)
|
||||
self.values.append(UNSET_VALUE)
|
||||
|
||||
@property
|
||||
def routing_key(self):
|
||||
if not self.prepared_statement.routing_key_indexes:
|
||||
|
||||
@@ -23,6 +23,9 @@
|
||||
.. autoclass:: BoundStatement
|
||||
:members:
|
||||
|
||||
.. autodata:: UNSET_VALUE
|
||||
:annotation:
|
||||
|
||||
.. autoclass:: BatchStatement (batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None)
|
||||
:members:
|
||||
|
||||
|
||||
@@ -22,12 +22,13 @@ except ImportError:
|
||||
from cassandra import InvalidRequest
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.query import PreparedStatement
|
||||
from cassandra.query import PreparedStatement, UNSET_VALUE
|
||||
|
||||
|
||||
def setup_module():
|
||||
use_singledc()
|
||||
|
||||
|
||||
class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
def test_basic(self):
|
||||
@@ -65,9 +66,9 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
session.execute(bound)
|
||||
|
||||
prepared = session.prepare(
|
||||
"""
|
||||
SELECT * FROM cf0 WHERE a=?
|
||||
""")
|
||||
"""
|
||||
SELECT * FROM cf0 WHERE a=?
|
||||
""")
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
|
||||
bound = prepared.bind(('a'))
|
||||
@@ -90,9 +91,9 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
session.execute(bound)
|
||||
|
||||
prepared = session.prepare(
|
||||
"""
|
||||
SELECT * FROM cf0 WHERE a=?
|
||||
""")
|
||||
"""
|
||||
SELECT * FROM cf0 WHERE a=?
|
||||
""")
|
||||
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
|
||||
@@ -163,7 +164,7 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
def test_too_many_bind_values_dicts(self):
|
||||
"""
|
||||
Ensure a ValueError is thrown when attempting to bind too many variables
|
||||
Ensure an error is thrown when attempting to bind the wrong values
|
||||
with dict bindings
|
||||
"""
|
||||
|
||||
@@ -172,15 +173,29 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
prepared = session.prepare(
|
||||
"""
|
||||
INSERT INTO test3rf.test (v) VALUES (?)
|
||||
INSERT INTO test3rf.test (k, v) VALUES (?, ?)
|
||||
""")
|
||||
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2})
|
||||
|
||||
# too many values
|
||||
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2, 'v2': 3})
|
||||
|
||||
# right number, but one does not belong
|
||||
if PROTOCOL_VERSION < 4:
|
||||
# pre v4, the driver bails with key error when 'v' is found missing
|
||||
self.assertRaises(KeyError, prepared.bind, {'k': 1, 'v2': 3})
|
||||
else:
|
||||
# post v4, the driver uses UNSET_VALUE for 'v' and bails when 'v2' is unbound
|
||||
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v2': 3})
|
||||
|
||||
# also catch too few variables with dicts
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
self.assertRaises(KeyError, prepared.bind, {})
|
||||
if PROTOCOL_VERSION < 4:
|
||||
self.assertRaises(KeyError, prepared.bind, {})
|
||||
else:
|
||||
# post v4, the driver attempts to use UNSET_VALUE for unspecified keys
|
||||
self.assertRaises(ValueError, prepared.bind, {})
|
||||
|
||||
cluster.shutdown()
|
||||
|
||||
@@ -202,9 +217,9 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
session.execute(bound)
|
||||
|
||||
prepared = session.prepare(
|
||||
"""
|
||||
SELECT * FROM test3rf.test WHERE k=?
|
||||
""")
|
||||
"""
|
||||
SELECT * FROM test3rf.test WHERE k=?
|
||||
""")
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
|
||||
bound = prepared.bind((1,))
|
||||
@@ -213,6 +228,56 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
cluster.shutdown()
|
||||
|
||||
def test_unset_values(self):
|
||||
"""
|
||||
Test to validate that UNSET_VALUEs are bound, and have the expected effect
|
||||
|
||||
Prepare a statement and insert all values. Then follow with execute excluding
|
||||
parameters. Verify that the original values are unaffected.
|
||||
|
||||
@since 2.6.0
|
||||
|
||||
@jira_ticket PYTHON-317
|
||||
@expected_result UNSET_VALUE is implicitly added to bind parameters, and properly encoded, leving unset values unaffected.
|
||||
|
||||
@test_category prepared_statements:binding
|
||||
"""
|
||||
if PROTOCOL_VERSION < 4:
|
||||
raise unittest.SkipTest("Binding UNSET values is not supported in protocol version < 4")
|
||||
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
# table with at least two values so one can be used as a marker
|
||||
session.execute("CREATE TABLE IF NOT EXISTS test1rf.test_unset_values (k int PRIMARY KEY, v0 int, v1 int)")
|
||||
insert = session.prepare("INSERT INTO test1rf.test_unset_values (k, v0, v1) VALUES (?, ?, ?)")
|
||||
select = session.prepare("SELECT * FROM test1rf.test_unset_values WHERE k=?")
|
||||
|
||||
bind_expected = [
|
||||
# initial condition
|
||||
((0, 0, 0), (0, 0, 0)),
|
||||
# unset implicit
|
||||
((0, 1,), (0, 1, 0)),
|
||||
({'k': 0, 'v0': 2}, (0, 2, 0)),
|
||||
({'k': 0, 'v1': 1}, (0, 2, 1)),
|
||||
# unset explicit
|
||||
((0, 3, UNSET_VALUE), (0, 3, 1)),
|
||||
((0, UNSET_VALUE, 2), (0, 3, 2)),
|
||||
({'k': 0, 'v0': 4, 'v1': UNSET_VALUE}, (0, 4, 2)),
|
||||
({'k': 0, 'v0': UNSET_VALUE, 'v1': 3}, (0, 4, 3)),
|
||||
# nulls still work
|
||||
((0, None, None), (0, None, None)),
|
||||
]
|
||||
|
||||
for params, expected in bind_expected:
|
||||
session.execute(insert, params)
|
||||
results = session.execute(select, (0,))
|
||||
self.assertEqual(results[0], expected)
|
||||
|
||||
self.assertRaises(ValueError, session.execute, select, (UNSET_VALUE, 0, 0))
|
||||
|
||||
cluster.shutdown()
|
||||
|
||||
def test_no_meta(self):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
@@ -227,9 +292,9 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
session.execute(bound)
|
||||
|
||||
prepared = session.prepare(
|
||||
"""
|
||||
SELECT * FROM test3rf.test WHERE k=0
|
||||
""")
|
||||
"""
|
||||
SELECT * FROM test3rf.test WHERE k=0
|
||||
""")
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
|
||||
bound = prepared.bind(None)
|
||||
@@ -257,9 +322,9 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
session.execute(bound)
|
||||
|
||||
prepared = session.prepare(
|
||||
"""
|
||||
SELECT * FROM test3rf.test WHERE k=?
|
||||
""")
|
||||
"""
|
||||
SELECT * FROM test3rf.test WHERE k=?
|
||||
""")
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
|
||||
bound = prepared.bind({'k': 1})
|
||||
@@ -286,9 +351,9 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
future.result()
|
||||
|
||||
prepared = session.prepare(
|
||||
"""
|
||||
SELECT * FROM test3rf.test WHERE k=?
|
||||
""")
|
||||
"""
|
||||
SELECT * FROM test3rf.test WHERE k=?
|
||||
""")
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
|
||||
future = session.execute_async(prepared, (873,))
|
||||
@@ -315,9 +380,9 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
future.result()
|
||||
|
||||
prepared = session.prepare(
|
||||
"""
|
||||
SELECT * FROM test3rf.test WHERE k=?
|
||||
""")
|
||||
"""
|
||||
SELECT * FROM test3rf.test WHERE k=?
|
||||
""")
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
|
||||
future = session.execute_async(prepared, {'k': 873})
|
||||
|
||||
@@ -18,8 +18,9 @@ except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.encoder import Encoder
|
||||
from cassandra.query import bind_params, ValueSequence
|
||||
from cassandra.query import PreparedStatement, BoundStatement
|
||||
from cassandra.protocol import ColumnMetadata
|
||||
from cassandra.query import (bind_params, ValueSequence, PreparedStatement,
|
||||
BoundStatement, UNSET_VALUE)
|
||||
from cassandra.cqltypes import Int32Type
|
||||
from cassandra.util import OrderedDict
|
||||
|
||||
@@ -73,42 +74,42 @@ class ParamBindingTest(unittest.TestCase):
|
||||
self.assertEqual(float(bind_params("%s", (f,), Encoder())), f)
|
||||
|
||||
|
||||
class BoundStatementTestCase(unittest.TestCase):
|
||||
class BoundStatementTestV1(unittest.TestCase):
|
||||
|
||||
protocol_version=1
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.prepared = PreparedStatement(column_metadata=[
|
||||
ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type),
|
||||
ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type),
|
||||
ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type),
|
||||
ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)
|
||||
],
|
||||
query_id=None,
|
||||
routing_key_indexes=[1, 0],
|
||||
query=None,
|
||||
keyspace='keyspace',
|
||||
protocol_version=cls.protocol_version)
|
||||
cls.bound = BoundStatement(prepared_statement=cls.prepared)
|
||||
|
||||
def test_invalid_argument_type(self):
|
||||
keyspace = 'keyspace1'
|
||||
column_family = 'cf1'
|
||||
|
||||
column_metadata = [
|
||||
(keyspace, column_family, 'foo1', Int32Type),
|
||||
(keyspace, column_family, 'foo2', Int32Type)
|
||||
]
|
||||
|
||||
prepared_statement = PreparedStatement(column_metadata=column_metadata,
|
||||
query_id=None,
|
||||
routing_key_indexes=[],
|
||||
query=None,
|
||||
keyspace=keyspace,
|
||||
protocol_version=2)
|
||||
bound_statement = BoundStatement(prepared_statement=prepared_statement)
|
||||
|
||||
values = ['nonint', 1]
|
||||
|
||||
values = (0, 0, 0, 'string not int')
|
||||
try:
|
||||
bound_statement.bind(values)
|
||||
self.bound.bind(values)
|
||||
except TypeError as e:
|
||||
self.assertIn('foo1', str(e))
|
||||
self.assertIn('v0', str(e))
|
||||
self.assertIn('Int32Type', str(e))
|
||||
self.assertIn('str', str(e))
|
||||
else:
|
||||
self.fail('Passed invalid type but exception was not thrown')
|
||||
|
||||
values = [1, ['1', '2']]
|
||||
values = (['1', '2'], 0, 0, 0)
|
||||
|
||||
try:
|
||||
bound_statement.bind(values)
|
||||
self.bound.bind(values)
|
||||
except TypeError as e:
|
||||
self.assertIn('foo2', str(e))
|
||||
self.assertIn('rk0', str(e))
|
||||
self.assertIn('Int32Type', str(e))
|
||||
self.assertIn('list', str(e))
|
||||
else:
|
||||
@@ -119,8 +120,8 @@ class BoundStatementTestCase(unittest.TestCase):
|
||||
column_family = 'cf1'
|
||||
|
||||
column_metadata = [
|
||||
(keyspace, column_family, 'foo1', Int32Type),
|
||||
(keyspace, column_family, 'foo2', Int32Type)
|
||||
ColumnMetadata(keyspace, column_family, 'foo1', Int32Type),
|
||||
ColumnMetadata(keyspace, column_family, 'foo2', Int32Type)
|
||||
]
|
||||
|
||||
prepared_statement = PreparedStatement(column_metadata=column_metadata,
|
||||
@@ -128,28 +129,87 @@ class BoundStatementTestCase(unittest.TestCase):
|
||||
routing_key_indexes=[],
|
||||
query=None,
|
||||
keyspace=keyspace,
|
||||
protocol_version=2)
|
||||
protocol_version=self.protocol_version)
|
||||
prepared_statement.fetch_size = 1234
|
||||
bound_statement = BoundStatement(prepared_statement=prepared_statement)
|
||||
self.assertEqual(1234, bound_statement.fetch_size)
|
||||
|
||||
def test_too_few_parameters_for_key(self):
|
||||
keyspace = 'keyspace1'
|
||||
column_family = 'cf1'
|
||||
def test_too_few_parameters_for_routing_key(self):
|
||||
self.assertRaises(ValueError, self.prepared.bind, (1,))
|
||||
|
||||
column_metadata = [
|
||||
(keyspace, column_family, 'foo1', Int32Type),
|
||||
(keyspace, column_family, 'foo2', Int32Type)
|
||||
]
|
||||
bound = self.prepared.bind((1, 2))
|
||||
self.assertEqual(bound.keyspace, 'keyspace')
|
||||
|
||||
prepared_statement = PreparedStatement(column_metadata=column_metadata,
|
||||
def test_dict_missing_routing_key(self):
|
||||
self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0})
|
||||
self.assertRaises(KeyError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0})
|
||||
|
||||
def test_missing_value(self):
|
||||
self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0})
|
||||
|
||||
def test_dict_extra_value(self):
|
||||
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123})
|
||||
self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, 0, 123))
|
||||
|
||||
def test_values_none(self):
|
||||
# should have values
|
||||
self.assertRaises(ValueError, self.bound.bind, None)
|
||||
|
||||
# prepared statement with no values
|
||||
prepared_statement = PreparedStatement(column_metadata=[],
|
||||
query_id=None,
|
||||
routing_key_indexes=[0, 1],
|
||||
routing_key_indexes=[],
|
||||
query=None,
|
||||
keyspace=keyspace,
|
||||
protocol_version=2)
|
||||
keyspace='whatever',
|
||||
protocol_version=self.protocol_version)
|
||||
bound = prepared_statement.bind(None)
|
||||
self.assertListEqual(bound.values, [])
|
||||
|
||||
self.assertRaises(ValueError, prepared_statement.bind, (1,))
|
||||
def test_bind_none(self):
|
||||
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None})
|
||||
self.assertEqual(self.bound.values[-1], None)
|
||||
|
||||
bound = prepared_statement.bind((1, 2))
|
||||
self.assertEqual(bound.keyspace, keyspace)
|
||||
old_values = self.bound.values
|
||||
self.bound.bind((0, 0, 0, None))
|
||||
self.assertIsNot(self.bound.values, old_values)
|
||||
self.assertEqual(self.bound.values[-1], None)
|
||||
|
||||
def test_unset_value(self):
|
||||
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
|
||||
self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, UNSET_VALUE))
|
||||
|
||||
|
||||
class BoundStatementTestV2(BoundStatementTestV1):
|
||||
protocol_version=2
|
||||
|
||||
|
||||
class BoundStatementTestV3(BoundStatementTestV1):
|
||||
protocol_version=3
|
||||
|
||||
|
||||
class BoundStatementTestV4(BoundStatementTestV1):
|
||||
protocol_version=4
|
||||
|
||||
def test_dict_missing_routing_key(self):
|
||||
# in v4 it implicitly binds UNSET_VALUE for missing items,
|
||||
# UNSET_VALUE is ValueError for routing keys
|
||||
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0})
|
||||
self.assertRaises(ValueError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0})
|
||||
|
||||
def test_missing_value(self):
|
||||
# in v4 missing values are UNSET_VALUE
|
||||
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0})
|
||||
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
|
||||
|
||||
old_values = self.bound.values
|
||||
self.bound.bind((0, 0, 0))
|
||||
self.assertIsNot(self.bound.values, old_values)
|
||||
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
|
||||
|
||||
def test_unset_value(self):
|
||||
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
|
||||
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
|
||||
|
||||
old_values = self.bound.values
|
||||
self.bound.bind((0, 0, 0, UNSET_VALUE))
|
||||
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
|
||||
|
||||
Reference in New Issue
Block a user