Merge pull request #321 from datastax/317
PYTHON-317 - Distinguish NULL and UNSET in bound parameters, protocol v4+
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from __future__ import absolute_import # to enable import io from stdlib
|
from __future__ import absolute_import # to enable import io from stdlib
|
||||||
|
from collections import namedtuple
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@@ -49,6 +50,7 @@ class NotSupportedError(Exception):
|
|||||||
class InternalError(Exception):
|
class InternalError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type'])
|
||||||
|
|
||||||
HEADER_DIRECTION_FROM_CLIENT = 0x00
|
HEADER_DIRECTION_FROM_CLIENT = 0x00
|
||||||
HEADER_DIRECTION_TO_CLIENT = 0x80
|
HEADER_DIRECTION_TO_CLIENT = 0x80
|
||||||
@@ -62,6 +64,7 @@ WARNING_FLAG = 0x08
|
|||||||
_message_types_by_name = {}
|
_message_types_by_name = {}
|
||||||
_message_types_by_opcode = {}
|
_message_types_by_opcode = {}
|
||||||
|
|
||||||
|
_UNSET_VALUE = object()
|
||||||
|
|
||||||
class _RegisterMessageType(type):
|
class _RegisterMessageType(type):
|
||||||
def __init__(cls, name, bases, dct):
|
def __init__(cls, name, bases, dct):
|
||||||
@@ -727,7 +730,7 @@ class ResultMessage(_MessageType):
|
|||||||
colcfname = read_string(f)
|
colcfname = read_string(f)
|
||||||
colname = read_string(f)
|
colname = read_string(f)
|
||||||
coltype = cls.read_type(f, user_type_map)
|
coltype = cls.read_type(f, user_type_map)
|
||||||
column_metadata.append((colksname, colcfname, colname, coltype))
|
column_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype))
|
||||||
return column_metadata, pk_indexes
|
return column_metadata, pk_indexes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1113,6 +1116,8 @@ def read_value(f):
|
|||||||
def write_value(f, v):
|
def write_value(f, v):
|
||||||
if v is None:
|
if v is None:
|
||||||
write_int(f, -1)
|
write_int(f, -1)
|
||||||
|
elif v is _UNSET_VALUE:
|
||||||
|
write_int(f, -2)
|
||||||
else:
|
else:
|
||||||
write_int(f, len(v))
|
write_int(f, len(v))
|
||||||
f.write(v)
|
f.write(v)
|
||||||
|
|||||||
@@ -24,16 +24,30 @@ import re
|
|||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
import six
|
import six
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
from cassandra import ConsistencyLevel, OperationTimedOut
|
from cassandra import ConsistencyLevel, OperationTimedOut
|
||||||
from cassandra.util import unix_time_from_uuid1
|
from cassandra.util import unix_time_from_uuid1
|
||||||
from cassandra.encoder import Encoder
|
from cassandra.encoder import Encoder
|
||||||
import cassandra.encoder
|
import cassandra.encoder
|
||||||
|
from cassandra.protocol import _UNSET_VALUE
|
||||||
from cassandra.util import OrderedDict
|
from cassandra.util import OrderedDict
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
UNSET_VALUE = _UNSET_VALUE
|
||||||
|
"""
|
||||||
|
Specifies an unset value when binding a prepared statement.
|
||||||
|
|
||||||
|
Unset values are ignored, allowing prepared statements to be used without specify
|
||||||
|
|
||||||
|
See https://issues.apache.org/jira/browse/CASSANDRA-7304 for further details on semantics.
|
||||||
|
|
||||||
|
.. versionadded:: 2.6.0
|
||||||
|
|
||||||
|
Only valid when using native protocol v4+
|
||||||
|
"""
|
||||||
|
|
||||||
NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]')
|
NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]')
|
||||||
START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*')
|
START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*')
|
||||||
@@ -350,6 +364,7 @@ class PreparedStatement(object):
|
|||||||
keyspace = None # change to prepared_keyspace in major release
|
keyspace = None # change to prepared_keyspace in major release
|
||||||
|
|
||||||
routing_key_indexes = None
|
routing_key_indexes = None
|
||||||
|
_routing_key_index_set = None
|
||||||
|
|
||||||
consistency_level = None
|
consistency_level = None
|
||||||
serial_consistency_level = None
|
serial_consistency_level = None
|
||||||
@@ -377,18 +392,17 @@ class PreparedStatement(object):
|
|||||||
if pk_indexes:
|
if pk_indexes:
|
||||||
routing_key_indexes = pk_indexes
|
routing_key_indexes = pk_indexes
|
||||||
else:
|
else:
|
||||||
partition_key_columns = None
|
|
||||||
routing_key_indexes = None
|
routing_key_indexes = None
|
||||||
|
|
||||||
ks_name, table_name, _, _ = column_metadata[0]
|
first_col = column_metadata[0]
|
||||||
ks_meta = cluster_metadata.keyspaces.get(ks_name)
|
ks_meta = cluster_metadata.keyspaces.get(first_col.keyspace_name)
|
||||||
if ks_meta:
|
if ks_meta:
|
||||||
table_meta = ks_meta.tables.get(table_name)
|
table_meta = ks_meta.tables.get(first_col.table_name)
|
||||||
if table_meta:
|
if table_meta:
|
||||||
partition_key_columns = table_meta.partition_key
|
partition_key_columns = table_meta.partition_key
|
||||||
|
|
||||||
# make a map of {column_name: index} for each column in the statement
|
# make a map of {column_name: index} for each column in the statement
|
||||||
statement_indexes = dict((c[2], i) for i, c in enumerate(column_metadata))
|
statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata))
|
||||||
|
|
||||||
# a list of which indexes in the statement correspond to partition key items
|
# a list of which indexes in the statement correspond to partition key items
|
||||||
try:
|
try:
|
||||||
@@ -403,11 +417,16 @@ class PreparedStatement(object):
|
|||||||
def bind(self, values):
|
def bind(self, values):
|
||||||
"""
|
"""
|
||||||
Creates and returns a :class:`BoundStatement` instance using `values`.
|
Creates and returns a :class:`BoundStatement` instance using `values`.
|
||||||
The `values` parameter **must** be a sequence, such as a tuple or list,
|
|
||||||
even if there is only one value to bind.
|
See :meth:`BoundStatement.bind` for rules on input ``values``.
|
||||||
"""
|
"""
|
||||||
return BoundStatement(self).bind(values)
|
return BoundStatement(self).bind(values)
|
||||||
|
|
||||||
|
def is_routing_key_index(self, i):
|
||||||
|
if self._routing_key_index_set is None:
|
||||||
|
self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
|
||||||
|
return i in self._routing_key_index_set
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
|
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
|
||||||
return (u'<PreparedStatement query="%s", consistency=%s>' %
|
return (u'<PreparedStatement query="%s", consistency=%s>' %
|
||||||
@@ -447,7 +466,7 @@ class BoundStatement(Statement):
|
|||||||
|
|
||||||
meta = prepared_statement.column_metadata
|
meta = prepared_statement.column_metadata
|
||||||
if meta:
|
if meta:
|
||||||
self.keyspace = meta[0][0]
|
self.keyspace = meta[0].keyspace_name
|
||||||
|
|
||||||
Statement.__init__(self, *args, **kwargs)
|
Statement.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
@@ -455,82 +474,95 @@ class BoundStatement(Statement):
|
|||||||
"""
|
"""
|
||||||
Binds a sequence of values for the prepared statement parameters
|
Binds a sequence of values for the prepared statement parameters
|
||||||
and returns this instance. Note that `values` *must* be:
|
and returns this instance. Note that `values` *must* be:
|
||||||
|
|
||||||
* a sequence, even if you are only binding one value, or
|
* a sequence, even if you are only binding one value, or
|
||||||
* a dict that relates 1-to-1 between dict keys and columns
|
* a dict that relates 1-to-1 between dict keys and columns
|
||||||
|
|
||||||
|
.. versionchanged:: 2.6.0
|
||||||
|
|
||||||
|
:data:`~.UNSET_VALUE` was introduced. These can be bound as positional parameters
|
||||||
|
in a sequence, or by name in a dict. Additionally, when using protocol v4+:
|
||||||
|
|
||||||
|
* short sequences will be extended to match bind parameters with UNSET_VALUE
|
||||||
|
* names may be omitted from a dict with UNSET_VALUE implied.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if values is None:
|
if values is None:
|
||||||
values = ()
|
values = ()
|
||||||
col_meta = self.prepared_statement.column_metadata
|
|
||||||
|
|
||||||
proto_version = self.prepared_statement.protocol_version
|
proto_version = self.prepared_statement.protocol_version
|
||||||
|
col_meta = self.prepared_statement.column_metadata
|
||||||
|
col_meta_len = len(col_meta)
|
||||||
|
value_len = len(values)
|
||||||
|
|
||||||
# special case for binding dicts
|
# special case for binding dicts
|
||||||
if isinstance(values, dict):
|
if isinstance(values, dict):
|
||||||
dict_values = values
|
unbound_values = values.copy()
|
||||||
values = []
|
values = []
|
||||||
|
|
||||||
# sort values accordingly
|
# sort values accordingly
|
||||||
for col in col_meta:
|
for col in col_meta:
|
||||||
try:
|
try:
|
||||||
values.append(dict_values[col[2]])
|
values.append(unbound_values.pop(col.name))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise KeyError(
|
if proto_version >= 4:
|
||||||
'Column name `%s` not found in bound dict.' %
|
values.append(UNSET_VALUE)
|
||||||
(col[2]))
|
else:
|
||||||
|
raise KeyError(
|
||||||
|
'Column name `%s` not found in bound dict.' %
|
||||||
|
(col.name))
|
||||||
|
|
||||||
# ensure a 1-to-1 dict keys to columns relationship
|
value_len = len(values)
|
||||||
if len(dict_values) != len(col_meta):
|
|
||||||
# find expected columns
|
|
||||||
columns = set()
|
|
||||||
for col in col_meta:
|
|
||||||
columns.add(col[2])
|
|
||||||
|
|
||||||
# generate error message
|
if unbound_values:
|
||||||
if len(dict_values) > len(col_meta):
|
raise ValueError("Unexpected arguments provided to bind(): %s" % unbound_values.keys())
|
||||||
difference = set(dict_values.keys()).difference(columns)
|
|
||||||
msg = "Too many arguments provided to bind() (got %d, expected %d). " + \
|
|
||||||
"Unexpected keys %s."
|
|
||||||
else:
|
|
||||||
difference = set(columns).difference(dict_values.keys())
|
|
||||||
msg = "Too few arguments provided to bind() (got %d, expected %d). " + \
|
|
||||||
"Expected keys %s."
|
|
||||||
|
|
||||||
# exit with error message
|
if value_len > col_meta_len:
|
||||||
msg = msg % (len(values), len(col_meta), difference)
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
if len(values) > len(col_meta):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Too many arguments provided to bind() (got %d, expected %d)" %
|
"Too many arguments provided to bind() (got %d, expected %d)" %
|
||||||
(len(values), len(col_meta)))
|
(len(values), len(col_meta)))
|
||||||
|
|
||||||
if self.prepared_statement.routing_key_indexes and \
|
# this is fail-fast for clarity pre-v4. When v4 can be assumed,
|
||||||
len(values) < len(self.prepared_statement.routing_key_indexes):
|
# the error will be better reported when UNSET_VALUE is implicitly added.
|
||||||
|
if proto_version < 4 and self.prepared_statement.routing_key_indexes and \
|
||||||
|
value_len < len(self.prepared_statement.routing_key_indexes):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Too few arguments provided to bind() (got %d, required %d for routing key)" %
|
"Too few arguments provided to bind() (got %d, required %d for routing key)" %
|
||||||
(len(values), len(self.prepared_statement.routing_key_indexes)))
|
(value_len, len(self.prepared_statement.routing_key_indexes)))
|
||||||
|
|
||||||
self.raw_values = values
|
self.raw_values = values
|
||||||
self.values = []
|
self.values = []
|
||||||
for value, col_spec in zip(values, col_meta):
|
for value, col_spec in zip(values, col_meta):
|
||||||
if value is None:
|
if value is None:
|
||||||
self.values.append(None)
|
self.values.append(None)
|
||||||
|
elif value is UNSET_VALUE:
|
||||||
|
if proto_version >= 4:
|
||||||
|
self._append_unset_value()
|
||||||
|
else:
|
||||||
|
raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version)
|
||||||
else:
|
else:
|
||||||
col_type = col_spec[-1]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.values.append(col_type.serialize(value, proto_version))
|
self.values.append(col_spec.type.serialize(value, proto_version))
|
||||||
except (TypeError, struct.error) as exc:
|
except (TypeError, struct.error) as exc:
|
||||||
col_name = col_spec[2]
|
|
||||||
expected_type = col_type
|
|
||||||
actual_type = type(value)
|
actual_type = type(value)
|
||||||
|
|
||||||
message = ('Received an argument of invalid type for column "%s". '
|
message = ('Received an argument of invalid type for column "%s". '
|
||||||
'Expected: %s, Got: %s; (%s)' % (col_name, expected_type, actual_type, exc))
|
'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc))
|
||||||
raise TypeError(message)
|
raise TypeError(message)
|
||||||
|
|
||||||
|
if proto_version >= 4:
|
||||||
|
diff = col_meta_len - len(self.values)
|
||||||
|
if diff:
|
||||||
|
for _ in range(diff):
|
||||||
|
self._append_unset_value()
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def _append_unset_value(self):
|
||||||
|
next_index = len(self.values)
|
||||||
|
if self.prepared_statement.is_routing_key_index(next_index):
|
||||||
|
col_meta = self.prepared_statement.column_metadata[next_index]
|
||||||
|
raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name)
|
||||||
|
self.values.append(UNSET_VALUE)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def routing_key(self):
|
def routing_key(self):
|
||||||
if not self.prepared_statement.routing_key_indexes:
|
if not self.prepared_statement.routing_key_indexes:
|
||||||
|
|||||||
@@ -23,6 +23,9 @@
|
|||||||
.. autoclass:: BoundStatement
|
.. autoclass:: BoundStatement
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
.. autodata:: UNSET_VALUE
|
||||||
|
:annotation:
|
||||||
|
|
||||||
.. autoclass:: BatchStatement (batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None)
|
.. autoclass:: BatchStatement (batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None)
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|||||||
@@ -22,12 +22,13 @@ except ImportError:
|
|||||||
from cassandra import InvalidRequest
|
from cassandra import InvalidRequest
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.query import PreparedStatement
|
from cassandra.query import PreparedStatement, UNSET_VALUE
|
||||||
|
|
||||||
|
|
||||||
def setup_module():
|
def setup_module():
|
||||||
use_singledc()
|
use_singledc()
|
||||||
|
|
||||||
|
|
||||||
class PreparedStatementTests(unittest.TestCase):
|
class PreparedStatementTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
@@ -65,9 +66,9 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
session.execute(bound)
|
session.execute(bound)
|
||||||
|
|
||||||
prepared = session.prepare(
|
prepared = session.prepare(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM cf0 WHERE a=?
|
SELECT * FROM cf0 WHERE a=?
|
||||||
""")
|
""")
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
|
|
||||||
bound = prepared.bind(('a'))
|
bound = prepared.bind(('a'))
|
||||||
@@ -90,9 +91,9 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
session.execute(bound)
|
session.execute(bound)
|
||||||
|
|
||||||
prepared = session.prepare(
|
prepared = session.prepare(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM cf0 WHERE a=?
|
SELECT * FROM cf0 WHERE a=?
|
||||||
""")
|
""")
|
||||||
|
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
|
|
||||||
@@ -163,7 +164,7 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
def test_too_many_bind_values_dicts(self):
|
def test_too_many_bind_values_dicts(self):
|
||||||
"""
|
"""
|
||||||
Ensure a ValueError is thrown when attempting to bind too many variables
|
Ensure an error is thrown when attempting to bind the wrong values
|
||||||
with dict bindings
|
with dict bindings
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -172,15 +173,29 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
prepared = session.prepare(
|
prepared = session.prepare(
|
||||||
"""
|
"""
|
||||||
INSERT INTO test3rf.test (v) VALUES (?)
|
INSERT INTO test3rf.test (k, v) VALUES (?, ?)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2})
|
|
||||||
|
# too many values
|
||||||
|
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2, 'v2': 3})
|
||||||
|
|
||||||
|
# right number, but one does not belong
|
||||||
|
if PROTOCOL_VERSION < 4:
|
||||||
|
# pre v4, the driver bails with key error when 'v' is found missing
|
||||||
|
self.assertRaises(KeyError, prepared.bind, {'k': 1, 'v2': 3})
|
||||||
|
else:
|
||||||
|
# post v4, the driver uses UNSET_VALUE for 'v' and bails when 'v2' is unbound
|
||||||
|
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v2': 3})
|
||||||
|
|
||||||
# also catch too few variables with dicts
|
# also catch too few variables with dicts
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
self.assertRaises(KeyError, prepared.bind, {})
|
if PROTOCOL_VERSION < 4:
|
||||||
|
self.assertRaises(KeyError, prepared.bind, {})
|
||||||
|
else:
|
||||||
|
# post v4, the driver attempts to use UNSET_VALUE for unspecified keys
|
||||||
|
self.assertRaises(ValueError, prepared.bind, {})
|
||||||
|
|
||||||
cluster.shutdown()
|
cluster.shutdown()
|
||||||
|
|
||||||
@@ -202,9 +217,9 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
session.execute(bound)
|
session.execute(bound)
|
||||||
|
|
||||||
prepared = session.prepare(
|
prepared = session.prepare(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM test3rf.test WHERE k=?
|
SELECT * FROM test3rf.test WHERE k=?
|
||||||
""")
|
""")
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
|
|
||||||
bound = prepared.bind((1,))
|
bound = prepared.bind((1,))
|
||||||
@@ -213,6 +228,56 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
|
|
||||||
cluster.shutdown()
|
cluster.shutdown()
|
||||||
|
|
||||||
|
def test_unset_values(self):
|
||||||
|
"""
|
||||||
|
Test to validate that UNSET_VALUEs are bound, and have the expected effect
|
||||||
|
|
||||||
|
Prepare a statement and insert all values. Then follow with execute excluding
|
||||||
|
parameters. Verify that the original values are unaffected.
|
||||||
|
|
||||||
|
@since 2.6.0
|
||||||
|
|
||||||
|
@jira_ticket PYTHON-317
|
||||||
|
@expected_result UNSET_VALUE is implicitly added to bind parameters, and properly encoded, leving unset values unaffected.
|
||||||
|
|
||||||
|
@test_category prepared_statements:binding
|
||||||
|
"""
|
||||||
|
if PROTOCOL_VERSION < 4:
|
||||||
|
raise unittest.SkipTest("Binding UNSET values is not supported in protocol version < 4")
|
||||||
|
|
||||||
|
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
|
session = cluster.connect()
|
||||||
|
|
||||||
|
# table with at least two values so one can be used as a marker
|
||||||
|
session.execute("CREATE TABLE IF NOT EXISTS test1rf.test_unset_values (k int PRIMARY KEY, v0 int, v1 int)")
|
||||||
|
insert = session.prepare("INSERT INTO test1rf.test_unset_values (k, v0, v1) VALUES (?, ?, ?)")
|
||||||
|
select = session.prepare("SELECT * FROM test1rf.test_unset_values WHERE k=?")
|
||||||
|
|
||||||
|
bind_expected = [
|
||||||
|
# initial condition
|
||||||
|
((0, 0, 0), (0, 0, 0)),
|
||||||
|
# unset implicit
|
||||||
|
((0, 1,), (0, 1, 0)),
|
||||||
|
({'k': 0, 'v0': 2}, (0, 2, 0)),
|
||||||
|
({'k': 0, 'v1': 1}, (0, 2, 1)),
|
||||||
|
# unset explicit
|
||||||
|
((0, 3, UNSET_VALUE), (0, 3, 1)),
|
||||||
|
((0, UNSET_VALUE, 2), (0, 3, 2)),
|
||||||
|
({'k': 0, 'v0': 4, 'v1': UNSET_VALUE}, (0, 4, 2)),
|
||||||
|
({'k': 0, 'v0': UNSET_VALUE, 'v1': 3}, (0, 4, 3)),
|
||||||
|
# nulls still work
|
||||||
|
((0, None, None), (0, None, None)),
|
||||||
|
]
|
||||||
|
|
||||||
|
for params, expected in bind_expected:
|
||||||
|
session.execute(insert, params)
|
||||||
|
results = session.execute(select, (0,))
|
||||||
|
self.assertEqual(results[0], expected)
|
||||||
|
|
||||||
|
self.assertRaises(ValueError, session.execute, select, (UNSET_VALUE, 0, 0))
|
||||||
|
|
||||||
|
cluster.shutdown()
|
||||||
|
|
||||||
def test_no_meta(self):
|
def test_no_meta(self):
|
||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
session = cluster.connect()
|
session = cluster.connect()
|
||||||
@@ -227,9 +292,9 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
session.execute(bound)
|
session.execute(bound)
|
||||||
|
|
||||||
prepared = session.prepare(
|
prepared = session.prepare(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM test3rf.test WHERE k=0
|
SELECT * FROM test3rf.test WHERE k=0
|
||||||
""")
|
""")
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
|
|
||||||
bound = prepared.bind(None)
|
bound = prepared.bind(None)
|
||||||
@@ -257,9 +322,9 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
session.execute(bound)
|
session.execute(bound)
|
||||||
|
|
||||||
prepared = session.prepare(
|
prepared = session.prepare(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM test3rf.test WHERE k=?
|
SELECT * FROM test3rf.test WHERE k=?
|
||||||
""")
|
""")
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
|
|
||||||
bound = prepared.bind({'k': 1})
|
bound = prepared.bind({'k': 1})
|
||||||
@@ -286,9 +351,9 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
future.result()
|
future.result()
|
||||||
|
|
||||||
prepared = session.prepare(
|
prepared = session.prepare(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM test3rf.test WHERE k=?
|
SELECT * FROM test3rf.test WHERE k=?
|
||||||
""")
|
""")
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
|
|
||||||
future = session.execute_async(prepared, (873,))
|
future = session.execute_async(prepared, (873,))
|
||||||
@@ -315,9 +380,9 @@ class PreparedStatementTests(unittest.TestCase):
|
|||||||
future.result()
|
future.result()
|
||||||
|
|
||||||
prepared = session.prepare(
|
prepared = session.prepare(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM test3rf.test WHERE k=?
|
SELECT * FROM test3rf.test WHERE k=?
|
||||||
""")
|
""")
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
|
|
||||||
future = session.execute_async(prepared, {'k': 873})
|
future = session.execute_async(prepared, {'k': 873})
|
||||||
|
|||||||
@@ -18,8 +18,9 @@ except ImportError:
|
|||||||
import unittest # noqa
|
import unittest # noqa
|
||||||
|
|
||||||
from cassandra.encoder import Encoder
|
from cassandra.encoder import Encoder
|
||||||
from cassandra.query import bind_params, ValueSequence
|
from cassandra.protocol import ColumnMetadata
|
||||||
from cassandra.query import PreparedStatement, BoundStatement
|
from cassandra.query import (bind_params, ValueSequence, PreparedStatement,
|
||||||
|
BoundStatement, UNSET_VALUE)
|
||||||
from cassandra.cqltypes import Int32Type
|
from cassandra.cqltypes import Int32Type
|
||||||
from cassandra.util import OrderedDict
|
from cassandra.util import OrderedDict
|
||||||
|
|
||||||
@@ -73,42 +74,42 @@ class ParamBindingTest(unittest.TestCase):
|
|||||||
self.assertEqual(float(bind_params("%s", (f,), Encoder())), f)
|
self.assertEqual(float(bind_params("%s", (f,), Encoder())), f)
|
||||||
|
|
||||||
|
|
||||||
class BoundStatementTestCase(unittest.TestCase):
|
class BoundStatementTestV1(unittest.TestCase):
|
||||||
|
|
||||||
|
protocol_version=1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.prepared = PreparedStatement(column_metadata=[
|
||||||
|
ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type),
|
||||||
|
ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type),
|
||||||
|
ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type),
|
||||||
|
ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)
|
||||||
|
],
|
||||||
|
query_id=None,
|
||||||
|
routing_key_indexes=[1, 0],
|
||||||
|
query=None,
|
||||||
|
keyspace='keyspace',
|
||||||
|
protocol_version=cls.protocol_version)
|
||||||
|
cls.bound = BoundStatement(prepared_statement=cls.prepared)
|
||||||
|
|
||||||
def test_invalid_argument_type(self):
|
def test_invalid_argument_type(self):
|
||||||
keyspace = 'keyspace1'
|
values = (0, 0, 0, 'string not int')
|
||||||
column_family = 'cf1'
|
|
||||||
|
|
||||||
column_metadata = [
|
|
||||||
(keyspace, column_family, 'foo1', Int32Type),
|
|
||||||
(keyspace, column_family, 'foo2', Int32Type)
|
|
||||||
]
|
|
||||||
|
|
||||||
prepared_statement = PreparedStatement(column_metadata=column_metadata,
|
|
||||||
query_id=None,
|
|
||||||
routing_key_indexes=[],
|
|
||||||
query=None,
|
|
||||||
keyspace=keyspace,
|
|
||||||
protocol_version=2)
|
|
||||||
bound_statement = BoundStatement(prepared_statement=prepared_statement)
|
|
||||||
|
|
||||||
values = ['nonint', 1]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
bound_statement.bind(values)
|
self.bound.bind(values)
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
self.assertIn('foo1', str(e))
|
self.assertIn('v0', str(e))
|
||||||
self.assertIn('Int32Type', str(e))
|
self.assertIn('Int32Type', str(e))
|
||||||
self.assertIn('str', str(e))
|
self.assertIn('str', str(e))
|
||||||
else:
|
else:
|
||||||
self.fail('Passed invalid type but exception was not thrown')
|
self.fail('Passed invalid type but exception was not thrown')
|
||||||
|
|
||||||
values = [1, ['1', '2']]
|
values = (['1', '2'], 0, 0, 0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
bound_statement.bind(values)
|
self.bound.bind(values)
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
self.assertIn('foo2', str(e))
|
self.assertIn('rk0', str(e))
|
||||||
self.assertIn('Int32Type', str(e))
|
self.assertIn('Int32Type', str(e))
|
||||||
self.assertIn('list', str(e))
|
self.assertIn('list', str(e))
|
||||||
else:
|
else:
|
||||||
@@ -119,8 +120,8 @@ class BoundStatementTestCase(unittest.TestCase):
|
|||||||
column_family = 'cf1'
|
column_family = 'cf1'
|
||||||
|
|
||||||
column_metadata = [
|
column_metadata = [
|
||||||
(keyspace, column_family, 'foo1', Int32Type),
|
ColumnMetadata(keyspace, column_family, 'foo1', Int32Type),
|
||||||
(keyspace, column_family, 'foo2', Int32Type)
|
ColumnMetadata(keyspace, column_family, 'foo2', Int32Type)
|
||||||
]
|
]
|
||||||
|
|
||||||
prepared_statement = PreparedStatement(column_metadata=column_metadata,
|
prepared_statement = PreparedStatement(column_metadata=column_metadata,
|
||||||
@@ -128,28 +129,87 @@ class BoundStatementTestCase(unittest.TestCase):
|
|||||||
routing_key_indexes=[],
|
routing_key_indexes=[],
|
||||||
query=None,
|
query=None,
|
||||||
keyspace=keyspace,
|
keyspace=keyspace,
|
||||||
protocol_version=2)
|
protocol_version=self.protocol_version)
|
||||||
prepared_statement.fetch_size = 1234
|
prepared_statement.fetch_size = 1234
|
||||||
bound_statement = BoundStatement(prepared_statement=prepared_statement)
|
bound_statement = BoundStatement(prepared_statement=prepared_statement)
|
||||||
self.assertEqual(1234, bound_statement.fetch_size)
|
self.assertEqual(1234, bound_statement.fetch_size)
|
||||||
|
|
||||||
def test_too_few_parameters_for_key(self):
|
def test_too_few_parameters_for_routing_key(self):
|
||||||
keyspace = 'keyspace1'
|
self.assertRaises(ValueError, self.prepared.bind, (1,))
|
||||||
column_family = 'cf1'
|
|
||||||
|
|
||||||
column_metadata = [
|
bound = self.prepared.bind((1, 2))
|
||||||
(keyspace, column_family, 'foo1', Int32Type),
|
self.assertEqual(bound.keyspace, 'keyspace')
|
||||||
(keyspace, column_family, 'foo2', Int32Type)
|
|
||||||
]
|
|
||||||
|
|
||||||
prepared_statement = PreparedStatement(column_metadata=column_metadata,
|
def test_dict_missing_routing_key(self):
|
||||||
|
self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0})
|
||||||
|
self.assertRaises(KeyError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0})
|
||||||
|
|
||||||
|
def test_missing_value(self):
|
||||||
|
self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0})
|
||||||
|
|
||||||
|
def test_dict_extra_value(self):
|
||||||
|
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123})
|
||||||
|
self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, 0, 123))
|
||||||
|
|
||||||
|
def test_values_none(self):
|
||||||
|
# should have values
|
||||||
|
self.assertRaises(ValueError, self.bound.bind, None)
|
||||||
|
|
||||||
|
# prepared statement with no values
|
||||||
|
prepared_statement = PreparedStatement(column_metadata=[],
|
||||||
query_id=None,
|
query_id=None,
|
||||||
routing_key_indexes=[0, 1],
|
routing_key_indexes=[],
|
||||||
query=None,
|
query=None,
|
||||||
keyspace=keyspace,
|
keyspace='whatever',
|
||||||
protocol_version=2)
|
protocol_version=self.protocol_version)
|
||||||
|
bound = prepared_statement.bind(None)
|
||||||
|
self.assertListEqual(bound.values, [])
|
||||||
|
|
||||||
self.assertRaises(ValueError, prepared_statement.bind, (1,))
|
def test_bind_none(self):
|
||||||
|
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None})
|
||||||
|
self.assertEqual(self.bound.values[-1], None)
|
||||||
|
|
||||||
bound = prepared_statement.bind((1, 2))
|
old_values = self.bound.values
|
||||||
self.assertEqual(bound.keyspace, keyspace)
|
self.bound.bind((0, 0, 0, None))
|
||||||
|
self.assertIsNot(self.bound.values, old_values)
|
||||||
|
self.assertEqual(self.bound.values[-1], None)
|
||||||
|
|
||||||
|
def test_unset_value(self):
|
||||||
|
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
|
||||||
|
self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, UNSET_VALUE))
|
||||||
|
|
||||||
|
|
||||||
|
class BoundStatementTestV2(BoundStatementTestV1):
|
||||||
|
protocol_version=2
|
||||||
|
|
||||||
|
|
||||||
|
class BoundStatementTestV3(BoundStatementTestV1):
|
||||||
|
protocol_version=3
|
||||||
|
|
||||||
|
|
||||||
|
class BoundStatementTestV4(BoundStatementTestV1):
|
||||||
|
protocol_version=4
|
||||||
|
|
||||||
|
def test_dict_missing_routing_key(self):
|
||||||
|
# in v4 it implicitly binds UNSET_VALUE for missing items,
|
||||||
|
# UNSET_VALUE is ValueError for routing keys
|
||||||
|
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0})
|
||||||
|
self.assertRaises(ValueError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0})
|
||||||
|
|
||||||
|
def test_missing_value(self):
|
||||||
|
# in v4 missing values are UNSET_VALUE
|
||||||
|
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0})
|
||||||
|
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
|
||||||
|
|
||||||
|
old_values = self.bound.values
|
||||||
|
self.bound.bind((0, 0, 0))
|
||||||
|
self.assertIsNot(self.bound.values, old_values)
|
||||||
|
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
|
||||||
|
|
||||||
|
def test_unset_value(self):
|
||||||
|
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
|
||||||
|
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
|
||||||
|
|
||||||
|
old_values = self.bound.values
|
||||||
|
self.bound.bind((0, 0, 0, UNSET_VALUE))
|
||||||
|
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
|
||||||
|
|||||||
Reference in New Issue
Block a user