Distinguish between NULL and UNSET when binding (native protocol v4+)

PYTHON-317
This commit is contained in:
Adam Holmberg
2015-05-20 16:55:13 -05:00
parent 8b8993c810
commit 0b8b37ba89
4 changed files with 94 additions and 33 deletions

View File

@@ -64,6 +64,7 @@ WARNING_FLAG = 0x08
_message_types_by_name = {}
_message_types_by_opcode = {}
_UNSET_VALUE = object()
class _RegisterMessageType(type):
def __init__(cls, name, bases, dct):
@@ -1115,6 +1116,8 @@ def read_value(f):
def write_value(f, v):
    """
    Write ``v`` to ``f`` preceded by an int length prefix.

    ``None`` is written as the int ``-1`` and the ``_UNSET_VALUE``
    sentinel as the int ``-2``; in both cases nothing else is written.
    Any other value is written as ``len(v)`` followed by ``v`` itself.
    """
    # The two sentinels are identified by identity, not equality.
    if v is None:
        write_int(f, -1)
        return
    if v is _UNSET_VALUE:
        write_int(f, -2)
        return
    write_int(f, len(v))
    f.write(v)

View File

@@ -24,16 +24,30 @@ import re
import struct
import time
import six
from six.moves import range
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.util import unix_time_from_uuid1
from cassandra.encoder import Encoder
import cassandra.encoder
from cassandra.protocol import _UNSET_VALUE
from cassandra.util import OrderedDict
import logging
log = logging.getLogger(__name__)
UNSET_VALUE = _UNSET_VALUE
"""
Specifies an unset value when binding a prepared statement.
Unset values are ignored, allowing prepared statements to be used without specifying a value for every bind parameter.
See https://issues.apache.org/jira/browse/CASSANDRA-7304 for further details on semantics.
.. versionadded:: 2.6.0
Only valid when using native protocol v4+
"""
NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]')
START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*')
@@ -350,6 +364,7 @@ class PreparedStatement(object):
keyspace = None # change to prepared_keyspace in major release
routing_key_indexes = None
_routing_key_index_set = None
consistency_level = None
serial_consistency_level = None
@@ -403,11 +418,16 @@ class PreparedStatement(object):
def bind(self, values):
    """
    Build a new :class:`BoundStatement` for this prepared statement and
    bind `values` to it.

    `values` is handed to :meth:`BoundStatement.bind` unchanged; see that
    method for the rules on what may be passed.
    """
    bound_statement = BoundStatement(self)
    return bound_statement.bind(values)
def is_routing_key_index(self, i):
    """
    Return ``True`` if ``i`` is one of this statement's
    ``routing_key_indexes``, ``False`` otherwise.

    The indexes are copied into a set the first time this is called; the
    set is stored on ``_routing_key_index_set`` and reused afterwards.
    """
    index_set = self._routing_key_index_set
    if index_set is None:
        # A missing or empty routing_key_indexes yields an empty set.
        index_set = set(self.routing_key_indexes or ())
        self._routing_key_index_set = index_set
    return i in index_set
def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
return (u'<PreparedStatement query="%s", consistency=%s>' %
@@ -455,64 +475,77 @@ class BoundStatement(Statement):
"""
Binds a sequence of values for the prepared statement parameters
and returns this instance. Note that `values` *must* be:
* a sequence, even if you are only binding one value, or
* a dict that relates 1-to-1 between dict keys and columns
.. versionchanged:: 2.6.0
:data:`~.UNSET_VALUE` was introduced. These can be bound as positional parameters
in a sequence, or by name in a dict. Additionally, when using protocol v4+:
* short sequences will be extended to match bind parameters with UNSET_VALUE
* names may be omitted from a dict with UNSET_VALUE implied.
"""
if values is None:
values = ()
col_meta = self.prepared_statement.column_metadata
proto_version = self.prepared_statement.protocol_version
col_meta = self.prepared_statement.column_metadata
col_meta_len = len(col_meta)
value_len = len(values)
# special case for binding dicts
if isinstance(values, dict):
dict_values = values
unbound_values = values.copy()
values = []
# sort values accordingly
for col in col_meta:
try:
values.append(dict_values[col.name])
values.append(unbound_values.pop(col.name))
except KeyError:
raise KeyError(
'Column name `%s` not found in bound dict.' %
(col.name))
if proto_version >= 4:
values.append(UNSET_VALUE)
else:
raise KeyError(
'Column name `%s` not found in bound dict.' %
(col.name))
# ensure a 1-to-1 dict keys to columns relationship
if len(dict_values) != len(col_meta):
# find expected columns
if unbound_values:
raise ValueError("Unexpected arguments provided to bind(): %s" % unbound_values.keys())
value_len = len(values)
if value_len < col_meta_len:
columns = set(col.name for col in col_meta)
difference = set(columns).difference(dict_values.keys())
raise ValueError("Too few arguments provided to bind() (got %d, expected %d). "
"Missing keys %s." % (value_len, col_meta_len, difference))
# generate error message
if len(dict_values) > len(col_meta):
difference = set(dict_values.keys()).difference(columns)
msg = "Too many arguments provided to bind() (got %d, expected %d). " + \
"Unexpected keys %s."
else:
difference = set(columns).difference(dict_values.keys())
msg = "Too few arguments provided to bind() (got %d, expected %d). " + \
"Expected keys %s."
# exit with error message
msg = msg % (len(values), len(col_meta), difference)
raise ValueError(msg)
if len(values) > len(col_meta):
if value_len > col_meta_len:
raise ValueError(
"Too many arguments provided to bind() (got %d, expected %d)" %
(len(values), len(col_meta)))
if self.prepared_statement.routing_key_indexes and \
len(values) < len(self.prepared_statement.routing_key_indexes):
# this is fail-fast for clarity pre-v4. When v4 can be assumed,
# the error will be better reported when UNSET_VALUE is implicitly added.
if proto_version < 4 and self.prepared_statement.routing_key_indexes and \
value_len < len(self.prepared_statement.routing_key_indexes):
raise ValueError(
"Too few arguments provided to bind() (got %d, required %d for routing key)" %
(len(values), len(self.prepared_statement.routing_key_indexes)))
(value_len, len(self.prepared_statement.routing_key_indexes)))
self.raw_values = values
self.values = []
for value, col_spec in zip(values, col_meta):
if value is None:
self.values.append(None)
elif value is UNSET_VALUE:
if proto_version >= 4:
self._append_unset_value()
else:
raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version)
else:
try:
self.values.append(col_spec.type.serialize(value, proto_version))
@@ -522,8 +555,21 @@ class BoundStatement(Statement):
'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc))
raise TypeError(message)
if proto_version >= 4:
diff = col_meta_len - len(self.values)
if diff:
for _ in range(diff):
self._append_unset_value()
return self
def _append_unset_value(self):
    """
    Append :data:`UNSET_VALUE` to ``self.values`` for the next bind position.

    The next position is the current length of ``self.values``. A
    :exc:`ValueError` naming the column is raised, and nothing is
    appended, when the prepared statement reports that position as a
    routing key index.
    """
    prepared = self.prepared_statement
    position = len(self.values)
    if not prepared.is_routing_key_index(position):
        self.values.append(UNSET_VALUE)
        return
    column = prepared.column_metadata[position]
    raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % column.name)
@property
def routing_key(self):
if not self.prepared_statement.routing_key_indexes:

View File

@@ -23,6 +23,9 @@
.. autoclass:: BoundStatement
:members:
.. autodata:: UNSET_VALUE
:annotation:
.. autoclass:: BatchStatement (batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None)
:members:

View File

@@ -172,15 +172,24 @@ class PreparedStatementTests(unittest.TestCase):
prepared = session.prepare(
"""
INSERT INTO test3rf.test (v) VALUES (?)
INSERT INTO test3rf.test (k, v) VALUES (?, ?)
""")
self.assertIsInstance(prepared, PreparedStatement)
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2})
# too many values
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2, 'v2': 3})
# right number, but one does not belong
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v2': 3})
# also catch too few variables with dicts
self.assertIsInstance(prepared, PreparedStatement)
self.assertRaises(KeyError, prepared.bind, {})
if PROTOCOL_VERSION < 4:
self.assertRaises(KeyError, prepared.bind, {})
else:
# post v4, the driver attempts to use UNSET_VALUE for unspecified keys
self.assertRaises(ValueError, prepared.bind, {})
cluster.shutdown()