Merge pull request #321 from datastax/317

PYTHON-317 - Distinguish NULL and UNSET in bound parameters, protocol v4+
This commit is contained in:
Adam Holmberg
2015-05-26 16:29:08 -05:00
5 changed files with 280 additions and 115 deletions

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import absolute_import # to enable import io from stdlib
from collections import namedtuple
import logging
import socket
from uuid import UUID
@@ -49,6 +50,7 @@ class NotSupportedError(Exception):
class InternalError(Exception):
pass
ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type'])
HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80
@@ -62,6 +64,7 @@ WARNING_FLAG = 0x08
_message_types_by_name = {}
_message_types_by_opcode = {}
_UNSET_VALUE = object()
class _RegisterMessageType(type):
def __init__(cls, name, bases, dct):
@@ -727,7 +730,7 @@ class ResultMessage(_MessageType):
colcfname = read_string(f)
colname = read_string(f)
coltype = cls.read_type(f, user_type_map)
column_metadata.append((colksname, colcfname, colname, coltype))
column_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype))
return column_metadata, pk_indexes
@classmethod
@@ -1113,6 +1116,8 @@ def read_value(f):
def write_value(f, v):
if v is None:
write_int(f, -1)
elif v is _UNSET_VALUE:
write_int(f, -2)
else:
write_int(f, len(v))
f.write(v)

View File

@@ -24,16 +24,30 @@ import re
import struct
import time
import six
from six.moves import range
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.util import unix_time_from_uuid1
from cassandra.encoder import Encoder
import cassandra.encoder
from cassandra.protocol import _UNSET_VALUE
from cassandra.util import OrderedDict
import logging
log = logging.getLogger(__name__)
UNSET_VALUE = _UNSET_VALUE
"""
Specifies an unset value when binding a prepared statement.
Unset values are ignored, allowing prepared statements to be used without specifying a value for every bind parameter.
See https://issues.apache.org/jira/browse/CASSANDRA-7304 for further details on semantics.
.. versionadded:: 2.6.0
Only valid when using native protocol v4+
"""
NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]')
START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*')
@@ -350,6 +364,7 @@ class PreparedStatement(object):
keyspace = None # change to prepared_keyspace in major release
routing_key_indexes = None
_routing_key_index_set = None
consistency_level = None
serial_consistency_level = None
@@ -377,18 +392,17 @@ class PreparedStatement(object):
if pk_indexes:
routing_key_indexes = pk_indexes
else:
partition_key_columns = None
routing_key_indexes = None
ks_name, table_name, _, _ = column_metadata[0]
ks_meta = cluster_metadata.keyspaces.get(ks_name)
first_col = column_metadata[0]
ks_meta = cluster_metadata.keyspaces.get(first_col.keyspace_name)
if ks_meta:
table_meta = ks_meta.tables.get(table_name)
table_meta = ks_meta.tables.get(first_col.table_name)
if table_meta:
partition_key_columns = table_meta.partition_key
# make a map of {column_name: index} for each column in the statement
statement_indexes = dict((c[2], i) for i, c in enumerate(column_metadata))
statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata))
# a list of which indexes in the statement correspond to partition key items
try:
@@ -403,11 +417,16 @@ class PreparedStatement(object):
def bind(self, values):
"""
Creates and returns a :class:`BoundStatement` instance using `values`.
The `values` parameter **must** be a sequence, such as a tuple or list,
even if there is only one value to bind.
See :meth:`BoundStatement.bind` for rules on input ``values``.
"""
return BoundStatement(self).bind(values)
def is_routing_key_index(self, i):
if self._routing_key_index_set is None:
self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
return i in self._routing_key_index_set
def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
return (u'<PreparedStatement query="%s", consistency=%s>' %
@@ -447,7 +466,7 @@ class BoundStatement(Statement):
meta = prepared_statement.column_metadata
if meta:
self.keyspace = meta[0][0]
self.keyspace = meta[0].keyspace_name
Statement.__init__(self, *args, **kwargs)
@@ -455,82 +474,95 @@ class BoundStatement(Statement):
"""
Binds a sequence of values for the prepared statement parameters
and returns this instance. Note that `values` *must* be:
* a sequence, even if you are only binding one value, or
* a dict that relates 1-to-1 between dict keys and columns
.. versionchanged:: 2.6.0
:data:`~.UNSET_VALUE` was introduced. These can be bound as positional parameters
in a sequence, or by name in a dict. Additionally, when using protocol v4+:
* short sequences will be extended to match bind parameters with UNSET_VALUE
* names may be omitted from a dict with UNSET_VALUE implied.
"""
if values is None:
values = ()
col_meta = self.prepared_statement.column_metadata
proto_version = self.prepared_statement.protocol_version
col_meta = self.prepared_statement.column_metadata
col_meta_len = len(col_meta)
value_len = len(values)
# special case for binding dicts
if isinstance(values, dict):
dict_values = values
unbound_values = values.copy()
values = []
# sort values accordingly
for col in col_meta:
try:
values.append(dict_values[col[2]])
values.append(unbound_values.pop(col.name))
except KeyError:
raise KeyError(
'Column name `%s` not found in bound dict.' %
(col[2]))
if proto_version >= 4:
values.append(UNSET_VALUE)
else:
raise KeyError(
'Column name `%s` not found in bound dict.' %
(col.name))
# ensure a 1-to-1 dict keys to columns relationship
if len(dict_values) != len(col_meta):
# find expected columns
columns = set()
for col in col_meta:
columns.add(col[2])
value_len = len(values)
# generate error message
if len(dict_values) > len(col_meta):
difference = set(dict_values.keys()).difference(columns)
msg = "Too many arguments provided to bind() (got %d, expected %d). " + \
"Unexpected keys %s."
else:
difference = set(columns).difference(dict_values.keys())
msg = "Too few arguments provided to bind() (got %d, expected %d). " + \
"Expected keys %s."
if unbound_values:
raise ValueError("Unexpected arguments provided to bind(): %s" % unbound_values.keys())
# exit with error message
msg = msg % (len(values), len(col_meta), difference)
raise ValueError(msg)
if len(values) > len(col_meta):
if value_len > col_meta_len:
raise ValueError(
"Too many arguments provided to bind() (got %d, expected %d)" %
(len(values), len(col_meta)))
if self.prepared_statement.routing_key_indexes and \
len(values) < len(self.prepared_statement.routing_key_indexes):
# this is fail-fast for clarity pre-v4. When v4 can be assumed,
# the error will be better reported when UNSET_VALUE is implicitly added.
if proto_version < 4 and self.prepared_statement.routing_key_indexes and \
value_len < len(self.prepared_statement.routing_key_indexes):
raise ValueError(
"Too few arguments provided to bind() (got %d, required %d for routing key)" %
(len(values), len(self.prepared_statement.routing_key_indexes)))
(value_len, len(self.prepared_statement.routing_key_indexes)))
self.raw_values = values
self.values = []
for value, col_spec in zip(values, col_meta):
if value is None:
self.values.append(None)
elif value is UNSET_VALUE:
if proto_version >= 4:
self._append_unset_value()
else:
raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version)
else:
col_type = col_spec[-1]
try:
self.values.append(col_type.serialize(value, proto_version))
self.values.append(col_spec.type.serialize(value, proto_version))
except (TypeError, struct.error) as exc:
col_name = col_spec[2]
expected_type = col_type
actual_type = type(value)
message = ('Received an argument of invalid type for column "%s". '
'Expected: %s, Got: %s; (%s)' % (col_name, expected_type, actual_type, exc))
'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc))
raise TypeError(message)
if proto_version >= 4:
diff = col_meta_len - len(self.values)
if diff:
for _ in range(diff):
self._append_unset_value()
return self
def _append_unset_value(self):
next_index = len(self.values)
if self.prepared_statement.is_routing_key_index(next_index):
col_meta = self.prepared_statement.column_metadata[next_index]
raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name)
self.values.append(UNSET_VALUE)
@property
def routing_key(self):
if not self.prepared_statement.routing_key_indexes:

View File

@@ -23,6 +23,9 @@
.. autoclass:: BoundStatement
:members:
.. autodata:: UNSET_VALUE
:annotation:
.. autoclass:: BatchStatement (batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None)
:members:

View File

@@ -22,12 +22,13 @@ except ImportError:
from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.query import PreparedStatement
from cassandra.query import PreparedStatement, UNSET_VALUE
def setup_module():
use_singledc()
class PreparedStatementTests(unittest.TestCase):
def test_basic(self):
@@ -65,9 +66,9 @@ class PreparedStatementTests(unittest.TestCase):
session.execute(bound)
prepared = session.prepare(
"""
SELECT * FROM cf0 WHERE a=?
""")
"""
SELECT * FROM cf0 WHERE a=?
""")
self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind(('a'))
@@ -90,9 +91,9 @@ class PreparedStatementTests(unittest.TestCase):
session.execute(bound)
prepared = session.prepare(
"""
SELECT * FROM cf0 WHERE a=?
""")
"""
SELECT * FROM cf0 WHERE a=?
""")
self.assertIsInstance(prepared, PreparedStatement)
@@ -163,7 +164,7 @@ class PreparedStatementTests(unittest.TestCase):
def test_too_many_bind_values_dicts(self):
"""
Ensure a ValueError is thrown when attempting to bind too many variables
Ensure an error is thrown when attempting to bind the wrong values
with dict bindings
"""
@@ -172,15 +173,29 @@ class PreparedStatementTests(unittest.TestCase):
prepared = session.prepare(
"""
INSERT INTO test3rf.test (v) VALUES (?)
INSERT INTO test3rf.test (k, v) VALUES (?, ?)
""")
self.assertIsInstance(prepared, PreparedStatement)
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2})
# too many values
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2, 'v2': 3})
# right number, but one does not belong
if PROTOCOL_VERSION < 4:
# pre v4, the driver bails with key error when 'v' is found missing
self.assertRaises(KeyError, prepared.bind, {'k': 1, 'v2': 3})
else:
# with v4+, the driver uses UNSET_VALUE for the missing 'v' and raises ValueError for the unexpected key 'v2'
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v2': 3})
# also catch too few variables with dicts
self.assertIsInstance(prepared, PreparedStatement)
self.assertRaises(KeyError, prepared.bind, {})
if PROTOCOL_VERSION < 4:
self.assertRaises(KeyError, prepared.bind, {})
else:
# with v4+, the driver attempts to use UNSET_VALUE for unspecified keys
self.assertRaises(ValueError, prepared.bind, {})
cluster.shutdown()
@@ -202,9 +217,9 @@ class PreparedStatementTests(unittest.TestCase):
session.execute(bound)
prepared = session.prepare(
"""
SELECT * FROM test3rf.test WHERE k=?
""")
"""
SELECT * FROM test3rf.test WHERE k=?
""")
self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind((1,))
@@ -213,6 +228,56 @@ class PreparedStatementTests(unittest.TestCase):
cluster.shutdown()
def test_unset_values(self):
"""
Test to validate that UNSET_VALUEs are bound, and have the expected effect
Prepare a statement and insert all values. Then follow with execute excluding
parameters. Verify that the original values are unaffected.
@since 2.6.0
@jira_ticket PYTHON-317
@expected_result UNSET_VALUE is implicitly added to bind parameters, and properly encoded, leaving unset values unaffected.
@test_category prepared_statements:binding
"""
if PROTOCOL_VERSION < 4:
raise unittest.SkipTest("Binding UNSET values is not supported in protocol version < 4")
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
# table with at least two values so one can be used as a marker
session.execute("CREATE TABLE IF NOT EXISTS test1rf.test_unset_values (k int PRIMARY KEY, v0 int, v1 int)")
insert = session.prepare("INSERT INTO test1rf.test_unset_values (k, v0, v1) VALUES (?, ?, ?)")
select = session.prepare("SELECT * FROM test1rf.test_unset_values WHERE k=?")
bind_expected = [
# initial condition
((0, 0, 0), (0, 0, 0)),
# unset implicit
((0, 1,), (0, 1, 0)),
({'k': 0, 'v0': 2}, (0, 2, 0)),
({'k': 0, 'v1': 1}, (0, 2, 1)),
# unset explicit
((0, 3, UNSET_VALUE), (0, 3, 1)),
((0, UNSET_VALUE, 2), (0, 3, 2)),
({'k': 0, 'v0': 4, 'v1': UNSET_VALUE}, (0, 4, 2)),
({'k': 0, 'v0': UNSET_VALUE, 'v1': 3}, (0, 4, 3)),
# nulls still work
((0, None, None), (0, None, None)),
]
for params, expected in bind_expected:
session.execute(insert, params)
results = session.execute(select, (0,))
self.assertEqual(results[0], expected)
self.assertRaises(ValueError, session.execute, select, (UNSET_VALUE, 0, 0))
cluster.shutdown()
def test_no_meta(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
@@ -227,9 +292,9 @@ class PreparedStatementTests(unittest.TestCase):
session.execute(bound)
prepared = session.prepare(
"""
SELECT * FROM test3rf.test WHERE k=0
""")
"""
SELECT * FROM test3rf.test WHERE k=0
""")
self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind(None)
@@ -257,9 +322,9 @@ class PreparedStatementTests(unittest.TestCase):
session.execute(bound)
prepared = session.prepare(
"""
SELECT * FROM test3rf.test WHERE k=?
""")
"""
SELECT * FROM test3rf.test WHERE k=?
""")
self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind({'k': 1})
@@ -286,9 +351,9 @@ class PreparedStatementTests(unittest.TestCase):
future.result()
prepared = session.prepare(
"""
SELECT * FROM test3rf.test WHERE k=?
""")
"""
SELECT * FROM test3rf.test WHERE k=?
""")
self.assertIsInstance(prepared, PreparedStatement)
future = session.execute_async(prepared, (873,))
@@ -315,9 +380,9 @@ class PreparedStatementTests(unittest.TestCase):
future.result()
prepared = session.prepare(
"""
SELECT * FROM test3rf.test WHERE k=?
""")
"""
SELECT * FROM test3rf.test WHERE k=?
""")
self.assertIsInstance(prepared, PreparedStatement)
future = session.execute_async(prepared, {'k': 873})

View File

@@ -18,8 +18,9 @@ except ImportError:
import unittest # noqa
from cassandra.encoder import Encoder
from cassandra.query import bind_params, ValueSequence
from cassandra.query import PreparedStatement, BoundStatement
from cassandra.protocol import ColumnMetadata
from cassandra.query import (bind_params, ValueSequence, PreparedStatement,
BoundStatement, UNSET_VALUE)
from cassandra.cqltypes import Int32Type
from cassandra.util import OrderedDict
@@ -73,42 +74,42 @@ class ParamBindingTest(unittest.TestCase):
self.assertEqual(float(bind_params("%s", (f,), Encoder())), f)
class BoundStatementTestCase(unittest.TestCase):
class BoundStatementTestV1(unittest.TestCase):
protocol_version=1
@classmethod
def setUpClass(cls):
cls.prepared = PreparedStatement(column_metadata=[
ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type),
ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type),
ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type),
ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)
],
query_id=None,
routing_key_indexes=[1, 0],
query=None,
keyspace='keyspace',
protocol_version=cls.protocol_version)
cls.bound = BoundStatement(prepared_statement=cls.prepared)
def test_invalid_argument_type(self):
keyspace = 'keyspace1'
column_family = 'cf1'
column_metadata = [
(keyspace, column_family, 'foo1', Int32Type),
(keyspace, column_family, 'foo2', Int32Type)
]
prepared_statement = PreparedStatement(column_metadata=column_metadata,
query_id=None,
routing_key_indexes=[],
query=None,
keyspace=keyspace,
protocol_version=2)
bound_statement = BoundStatement(prepared_statement=prepared_statement)
values = ['nonint', 1]
values = (0, 0, 0, 'string not int')
try:
bound_statement.bind(values)
self.bound.bind(values)
except TypeError as e:
self.assertIn('foo1', str(e))
self.assertIn('v0', str(e))
self.assertIn('Int32Type', str(e))
self.assertIn('str', str(e))
else:
self.fail('Passed invalid type but exception was not thrown')
values = [1, ['1', '2']]
values = (['1', '2'], 0, 0, 0)
try:
bound_statement.bind(values)
self.bound.bind(values)
except TypeError as e:
self.assertIn('foo2', str(e))
self.assertIn('rk0', str(e))
self.assertIn('Int32Type', str(e))
self.assertIn('list', str(e))
else:
@@ -119,8 +120,8 @@ class BoundStatementTestCase(unittest.TestCase):
column_family = 'cf1'
column_metadata = [
(keyspace, column_family, 'foo1', Int32Type),
(keyspace, column_family, 'foo2', Int32Type)
ColumnMetadata(keyspace, column_family, 'foo1', Int32Type),
ColumnMetadata(keyspace, column_family, 'foo2', Int32Type)
]
prepared_statement = PreparedStatement(column_metadata=column_metadata,
@@ -128,28 +129,87 @@ class BoundStatementTestCase(unittest.TestCase):
routing_key_indexes=[],
query=None,
keyspace=keyspace,
protocol_version=2)
protocol_version=self.protocol_version)
prepared_statement.fetch_size = 1234
bound_statement = BoundStatement(prepared_statement=prepared_statement)
self.assertEqual(1234, bound_statement.fetch_size)
def test_too_few_parameters_for_key(self):
keyspace = 'keyspace1'
column_family = 'cf1'
def test_too_few_parameters_for_routing_key(self):
self.assertRaises(ValueError, self.prepared.bind, (1,))
column_metadata = [
(keyspace, column_family, 'foo1', Int32Type),
(keyspace, column_family, 'foo2', Int32Type)
]
bound = self.prepared.bind((1, 2))
self.assertEqual(bound.keyspace, 'keyspace')
prepared_statement = PreparedStatement(column_metadata=column_metadata,
def test_dict_missing_routing_key(self):
self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0})
self.assertRaises(KeyError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0})
def test_missing_value(self):
self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0})
def test_dict_extra_value(self):
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123})
self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, 0, 123))
def test_values_none(self):
# should have values
self.assertRaises(ValueError, self.bound.bind, None)
# prepared statement with no values
prepared_statement = PreparedStatement(column_metadata=[],
query_id=None,
routing_key_indexes=[0, 1],
routing_key_indexes=[],
query=None,
keyspace=keyspace,
protocol_version=2)
keyspace='whatever',
protocol_version=self.protocol_version)
bound = prepared_statement.bind(None)
self.assertListEqual(bound.values, [])
self.assertRaises(ValueError, prepared_statement.bind, (1,))
def test_bind_none(self):
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None})
self.assertEqual(self.bound.values[-1], None)
bound = prepared_statement.bind((1, 2))
self.assertEqual(bound.keyspace, keyspace)
old_values = self.bound.values
self.bound.bind((0, 0, 0, None))
self.assertIsNot(self.bound.values, old_values)
self.assertEqual(self.bound.values[-1], None)
def test_unset_value(self):
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, UNSET_VALUE))
class BoundStatementTestV2(BoundStatementTestV1):
protocol_version=2
class BoundStatementTestV3(BoundStatementTestV1):
protocol_version=3
class BoundStatementTestV4(BoundStatementTestV1):
protocol_version=4
def test_dict_missing_routing_key(self):
# in v4, missing items are implicitly bound as UNSET_VALUE,
# and binding UNSET_VALUE to a routing key raises ValueError
self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0})
self.assertRaises(ValueError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0})
def test_missing_value(self):
# in v4 missing values are UNSET_VALUE
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0})
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
old_values = self.bound.values
self.bound.bind((0, 0, 0))
self.assertIsNot(self.bound.values, old_values)
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
def test_unset_value(self):
self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
self.assertEqual(self.bound.values[-1], UNSET_VALUE)
old_values = self.bound.values
self.bound.bind((0, 0, 0, UNSET_VALUE))
self.assertEqual(self.bound.values[-1], UNSET_VALUE)