220 lines
		
	
	
		
			8.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			220 lines
		
	
	
		
			8.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright 2013-2017 DataStax, Inc.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
# http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
 | 
						|
try:
 | 
						|
    import unittest2 as unittest
 | 
						|
except ImportError:
 | 
						|
    import unittest # noqa
 | 
						|
 | 
						|
from cassandra.encoder import Encoder
 | 
						|
from cassandra.protocol import ColumnMetadata
 | 
						|
from cassandra.query import (bind_params, ValueSequence, PreparedStatement,
 | 
						|
                             BoundStatement, UNSET_VALUE)
 | 
						|
from cassandra.cqltypes import Int32Type
 | 
						|
from cassandra.util import OrderedDict
 | 
						|
 | 
						|
from six.moves import xrange
 | 
						|
import six
 | 
						|
 | 
						|
 | 
						|
class ParamBindingTest(unittest.TestCase):
    """Checks the text ``bind_params`` produces for different kinds of parameter values."""

    def test_bind_sequence(self):
        # Positional placeholders are filled from a tuple, in order.
        rendered = bind_params("%s %s %s", (1, "a", 2.0), Encoder())
        self.assertEqual(rendered, "1 'a' 2.0")

    def test_bind_map(self):
        # Named placeholders are filled from a mapping.
        params = {'a': 1, 'b': "a", 'c': 2.0}
        rendered = bind_params("%(a)s %(b)s %(c)s", params, Encoder())
        self.assertEqual(rendered, "1 'a' 2.0")

    def test_sequence_param(self):
        # A ValueSequence is rendered as a parenthesised, comma-separated group.
        sequence = ValueSequence((1, "a", 2.0))
        rendered = bind_params("%s", (sequence,), Encoder())
        self.assertEqual(rendered, "(1, 'a', 2.0)")

    def test_generator_param(self):
        # A generator is consumed and rendered with square brackets.
        numbers = (i for i in xrange(3))
        rendered = bind_params("%s", (numbers,), Encoder())
        self.assertEqual(rendered, "[0, 1, 2]")

    def test_none_param(self):
        # None is rendered as the bare word NULL.
        rendered = bind_params("%s", (None,), Encoder())
        self.assertEqual(rendered, "NULL")

    def test_list_collection(self):
        letters = ['a', 'b', 'c']
        rendered = bind_params("%s", (letters,), Encoder())
        self.assertEqual(rendered, "['a', 'b', 'c']")

    def test_set_collection(self):
        # Set iteration order is not fixed, so either ordering is accepted.
        acceptable = ("{'a', 'b'}", "{'b', 'a'}")
        rendered = bind_params("%s", (set(['a', 'b']),), Encoder())
        self.assertIn(rendered, acceptable)

    def test_map_collection(self):
        # An OrderedDict keeps the expected output deterministic.
        ordered = OrderedDict()
        for key in ('a', 'b', 'c'):
            ordered[key] = key
        rendered = bind_params("%s", (ordered,), Encoder())
        self.assertEqual(rendered, "{'a': 'a', 'b': 'b', 'c': 'c'}")

    def test_quote_escaping(self):
        # Single quotes are doubled in the output; double quotes pass through unchanged.
        raw = """'ef''ef"ef""ef'"""
        rendered = bind_params("%s", (raw,), Encoder())
        self.assertEqual(rendered, """'''ef''''ef"ef""ef'''""")

    def test_float_precision(self):
        # The rendered text must parse back to exactly the same float.
        value = 3.4028234663852886e+38
        round_tripped = float(bind_params("%s", (value,), Encoder()))
        self.assertEqual(round_tripped, value)
 | 
						|
 | 
						|
 | 
						|
class BoundStatementTestV1(unittest.TestCase):
    """Exercises ``BoundStatement.bind`` against a four-column prepared statement.

    ``protocol_version`` is a class attribute that is passed to every
    ``PreparedStatement`` built here, so a subclass can rerun the same checks
    under another protocol version by overriding it.
    """

    protocol_version = 1

    @classmethod
    def setUpClass(cls):
        # Four Int32 columns; routing_key_indexes=[1, 0] refers to rk1 then rk0.
        columns = [ColumnMetadata('keyspace', 'cf', column_name, Int32Type)
                   for column_name in ('rk0', 'rk1', 'ck0', 'v0')]
        cls.prepared = PreparedStatement(column_metadata=columns,
                                         query_id=None,
                                         routing_key_indexes=[1, 0],
                                         query=None,
                                         keyspace='keyspace',
                                         protocol_version=cls.protocol_version,
                                         result_metadata=None)
        cls.bound = BoundStatement(prepared_statement=cls.prepared)

    def test_invalid_argument_type(self):
        # Each case: the values to bind, then the fragments (column name,
        # expected type, offending Python type) the TypeError text must contain.
        cases = (
            ((0, 0, 0, 'string not int'), ('v0', 'Int32Type', 'str')),
            ((['1', '2'], 0, 0, 0), ('rk0', 'Int32Type', 'list')),
        )
        for values, fragments in cases:
            try:
                self.bound.bind(values)
            except TypeError as exc:
                for fragment in fragments:
                    self.assertIn(fragment, str(exc))
            else:
                self.fail('Passed invalid type but exception was not thrown')

    def test_inherit_fetch_size(self):
        # A BoundStatement should pick up fetch_size from its prepared statement.
        ks_name, table_name = 'keyspace1', 'cf1'
        columns = [ColumnMetadata(ks_name, table_name, column_name, Int32Type)
                   for column_name in ('foo1', 'foo2')]

        prepared = PreparedStatement(column_metadata=columns,
                                     query_id=None,
                                     routing_key_indexes=[],
                                     query=None,
                                     keyspace=ks_name,
                                     protocol_version=self.protocol_version,
                                     result_metadata=None)
        prepared.fetch_size = 1234

        derived = BoundStatement(prepared_statement=prepared)
        self.assertEqual(1234, derived.fetch_size)

    def test_too_few_parameters_for_routing_key(self):
        # One value is fewer than the two routing key indexes configured in setUpClass.
        with self.assertRaises(ValueError):
            self.prepared.bind((1,))

        # Two values are accepted and the bound statement reports the keyspace.
        two_bound = self.prepared.bind((1, 2))
        self.assertEqual(two_bound.keyspace, 'keyspace')

    def test_dict_missing_routing_key(self):
        # Leaving out either routing key column (rk1, then rk0) must raise KeyError.
        incomplete_dicts = (
            {'rk0': 0, 'ck0': 0, 'v0': 0},
            {'rk1': 0, 'ck0': 0, 'v0': 0},
        )
        for incomplete in incomplete_dicts:
            self.assertRaises(KeyError, self.bound.bind, incomplete)

    def test_missing_value(self):
        # 'v0' is absent from the dict.
        without_v0 = {'rk0': 0, 'rk1': 0, 'ck0': 0}
        with self.assertRaises(KeyError):
            self.bound.bind(without_v0)

    def test_extra_value(self):
        # okay to have extra keys in dict
        self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123})
        encoded_zero = six.b('\x00') * 4
        self.assertEqual(self.bound.values, [encoded_zero] * 4)  # four encoded zeros

        # ...but a positional sequence longer than the column list is rejected.
        with self.assertRaises(ValueError):
            self.bound.bind((0, 0, 0, 0, 123))

    def test_values_none(self):
        # should have values
        with self.assertRaises(ValueError):
            self.bound.bind(None)

        # prepared statement with no values
        empty_prepared = PreparedStatement(column_metadata=[],
                                           query_id=None,
                                           routing_key_indexes=[],
                                           query=None,
                                           keyspace='whatever',
                                           protocol_version=self.protocol_version,
                                           result_metadata=None)
        empty_bound = empty_prepared.bind(None)
        self.assertListEqual(empty_bound.values, [])

    def test_bind_none(self):
        # None for the last column is kept as None, via dict binding...
        self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None})
        self.assertEqual(self.bound.values[-1], None)

        # ...and via positional binding, which must install a new values list.
        previous_values = self.bound.values
        self.bound.bind((0, 0, 0, None))
        self.assertIsNot(self.bound.values, previous_values)
        self.assertEqual(self.bound.values[-1], None)

    def test_unset_value(self):
        # UNSET_VALUE is rejected in both the dict and the positional form.
        rejected_inputs = (
            {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE},
            (0, 0, 0, UNSET_VALUE),
        )
        for rejected in rejected_inputs:
            self.assertRaises(ValueError, self.bound.bind, rejected)
 | 
						|
 | 
						|
 | 
						|
class BoundStatementTestV2(BoundStatementTestV1):
    """Runs the inherited BoundStatementTestV1 checks with protocol_version set to 2."""

    protocol_version = 2
 | 
						|
 | 
						|
 | 
						|
class BoundStatementTestV3(BoundStatementTestV1):
    """Runs the inherited BoundStatementTestV1 checks with protocol_version set to 3."""

    protocol_version = 3
 | 
						|
 | 
						|
 | 
						|
class BoundStatementTestV4(BoundStatementTestV1):
    """Runs the BoundStatementTestV1 checks with protocol_version set to 4.

    Three inherited tests are overridden because, per the expectations below,
    v4 binds ``UNSET_VALUE`` for missing items instead of raising ``KeyError``,
    and accepts an explicit ``UNSET_VALUE`` for a non-routing-key column.
    """

    protocol_version = 4

    def test_dict_missing_routing_key(self):
        # in v4 it implicitly binds UNSET_VALUE for missing items,
        # UNSET_VALUE is ValueError for routing keys
        self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0})
        self.assertRaises(ValueError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0})

    def test_missing_value(self):
        # in v4 missing values are UNSET_VALUE
        self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0})
        self.assertEqual(self.bound.values[-1], UNSET_VALUE)

        # Rebinding positionally must install a fresh values list, not reuse the old one.
        old_values = self.bound.values
        self.bound.bind((0, 0, 0))
        self.assertIsNot(self.bound.values, old_values)
        self.assertEqual(self.bound.values[-1], UNSET_VALUE)

    def test_unset_value(self):
        # An explicit UNSET_VALUE for the last column is kept as-is.
        self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
        self.assertEqual(self.bound.values[-1], UNSET_VALUE)

        # old_values was previously captured but never checked; assert the
        # rebind replaced the list, matching test_missing_value in this class.
        old_values = self.bound.values
        self.bound.bind((0, 0, 0, UNSET_VALUE))
        self.assertIsNot(self.bound.values, old_values)
        self.assertEqual(self.bound.values[-1], UNSET_VALUE)
 |