# Copyright 2013-2017 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. try: import unittest2 as unittest except ImportError: import unittest # noqa from cassandra.encoder import Encoder from cassandra.protocol import ColumnMetadata from cassandra.query import (bind_params, ValueSequence, PreparedStatement, BoundStatement, UNSET_VALUE) from cassandra.cqltypes import Int32Type from cassandra.util import OrderedDict from six.moves import xrange import six class ParamBindingTest(unittest.TestCase): def test_bind_sequence(self): result = bind_params("%s %s %s", (1, "a", 2.0), Encoder()) self.assertEqual(result, "1 'a' 2.0") def test_bind_map(self): result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), Encoder()) self.assertEqual(result, "1 'a' 2.0") def test_sequence_param(self): result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), Encoder()) self.assertEqual(result, "(1, 'a', 2.0)") def test_generator_param(self): result = bind_params("%s", ((i for i in xrange(3)),), Encoder()) self.assertEqual(result, "[0, 1, 2]") def test_none_param(self): result = bind_params("%s", (None,), Encoder()) self.assertEqual(result, "NULL") def test_list_collection(self): result = bind_params("%s", (['a', 'b', 'c'],), Encoder()) self.assertEqual(result, "['a', 'b', 'c']") def test_set_collection(self): result = bind_params("%s", (set(['a', 'b']),), Encoder()) self.assertIn(result, ("{'a', 'b'}", "{'b', 'a'}")) def test_map_collection(self): vals = OrderedDict() vals['a'] = 'a' vals['b'] = 'b' vals['c'] = 'c' result = bind_params("%s", (vals,), Encoder()) self.assertEqual(result, "{'a': 'a', 'b': 'b', 'c': 'c'}") def test_quote_escaping(self): result = bind_params("%s", ("""'ef''ef"ef""ef'""",), Encoder()) self.assertEqual(result, """'''ef''''ef"ef""ef'''""") def test_float_precision(self): f = 3.4028234663852886e+38 self.assertEqual(float(bind_params("%s", (f,), Encoder())), f) class BoundStatementTestV1(unittest.TestCase): protocol_version=1 @classmethod def setUpClass(cls): cls.prepared = PreparedStatement(column_metadata=[ ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type), ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type), ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type), ColumnMetadata('keyspace', 'cf', 'v0', Int32Type) ], query_id=None, routing_key_indexes=[1, 0], query=None, keyspace='keyspace', protocol_version=cls.protocol_version, result_metadata=None) cls.bound = BoundStatement(prepared_statement=cls.prepared) def test_invalid_argument_type(self): values = (0, 0, 0, 'string not int') try: self.bound.bind(values) except TypeError as e: self.assertIn('v0', str(e)) self.assertIn('Int32Type', str(e)) self.assertIn('str', str(e)) else: self.fail('Passed invalid type but exception was not thrown') values = (['1', '2'], 0, 0, 0) try: self.bound.bind(values) except TypeError as e: self.assertIn('rk0', str(e)) self.assertIn('Int32Type', str(e)) self.assertIn('list', str(e)) else: self.fail('Passed invalid type but exception was not thrown') def test_inherit_fetch_size(self): keyspace = 'keyspace1' column_family = 'cf1' column_metadata = [ ColumnMetadata(keyspace, column_family, 'foo1', Int32Type), ColumnMetadata(keyspace, column_family, 'foo2', Int32Type) ] prepared_statement = PreparedStatement(column_metadata=column_metadata, query_id=None, routing_key_indexes=[], query=None, keyspace=keyspace, protocol_version=self.protocol_version, result_metadata=None) prepared_statement.fetch_size = 1234 bound_statement = BoundStatement(prepared_statement=prepared_statement) self.assertEqual(1234, bound_statement.fetch_size) def test_too_few_parameters_for_routing_key(self): self.assertRaises(ValueError, self.prepared.bind, (1,)) bound = self.prepared.bind((1, 2)) self.assertEqual(bound.keyspace, 'keyspace') def test_dict_missing_routing_key(self): self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0}) self.assertRaises(KeyError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0}) def test_missing_value(self): self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0}) def test_extra_value(self): self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123}) # okay to have extra keys in dict self.assertEqual(self.bound.values, [six.b('\x00') * 4] * 4) # four encoded zeros self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, 0, 123)) def test_values_none(self): # should have values self.assertRaises(ValueError, self.bound.bind, None) # prepared statement with no values prepared_statement = PreparedStatement(column_metadata=[], query_id=None, routing_key_indexes=[], query=None, keyspace='whatever', protocol_version=self.protocol_version, result_metadata=None) bound = prepared_statement.bind(None) self.assertListEqual(bound.values, []) def test_bind_none(self): self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None}) self.assertEqual(self.bound.values[-1], None) old_values = self.bound.values self.bound.bind((0, 0, 0, None)) self.assertIsNot(self.bound.values, old_values) self.assertEqual(self.bound.values[-1], None) def test_unset_value(self): self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, UNSET_VALUE)) class BoundStatementTestV2(BoundStatementTestV1): protocol_version=2 class BoundStatementTestV3(BoundStatementTestV1): protocol_version=3 class BoundStatementTestV4(BoundStatementTestV1): protocol_version=4 def test_dict_missing_routing_key(self): # in v4 it implicitly binds UNSET_VALUE for missing items, # UNSET_VALUE is ValueError for routing keys self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0}) self.assertRaises(ValueError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0}) def test_missing_value(self): # in v4 missing values are UNSET_VALUE self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0}) self.assertEqual(self.bound.values[-1], UNSET_VALUE) old_values = self.bound.values self.bound.bind((0, 0, 0)) self.assertIsNot(self.bound.values, old_values) self.assertEqual(self.bound.values[-1], UNSET_VALUE) def test_unset_value(self): self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) self.assertEqual(self.bound.values[-1], UNSET_VALUE) old_values = self.bound.values self.bound.bind((0, 0, 0, UNSET_VALUE)) self.assertEqual(self.bound.values[-1], UNSET_VALUE)