Files
deb-python-cassandra-driver/tests/unit/test_parameter_binding.py
Tyler Hobbs 8bee26d6d6 Handle custom encoders in nested data types
The root of the problem was that nested data types would use the default
encoders for subitems.  When the encoders were customized, they would
not be used for those nested items.

This fix moves the encoder functions into a class so that collections,
tuples, and UDTs will use the customized mapping when encoding subitems.

Fixes PYTHON-100.
2014-07-22 14:21:04 -05:00

131 lines
4.8 KiB
Python

# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from cassandra.encoder import default_encoder
from cassandra.query import bind_params, ValueSequence
from cassandra.query import PreparedStatement, BoundStatement
from cassandra.cqltypes import Int32Type
from cassandra.util import OrderedDict
from six.moves import xrange
class ParamBindingTest(unittest.TestCase):
    """Checks how bind_params renders Python values into CQL query text."""

    def _bind(self, query, params):
        # Every case in this class uses the stock encoder, so route all
        # bindings through one place.
        return bind_params(query, params, default_encoder)

    def test_bind_sequence(self):
        # Positional %s placeholders are filled from a tuple.
        self.assertEqual(self._bind("%s %s %s", (1, "a", 2.0)), "1 'a' 2.0")

    def test_bind_map(self):
        # Named %(key)s placeholders are filled from a mapping.
        named = {'a': 1, 'b': "a", 'c': 2.0}
        self.assertEqual(self._bind("%(a)s %(b)s %(c)s", named), "1 'a' 2.0")

    def test_sequence_param(self):
        # A ValueSequence renders as a parenthesised, comma-separated list.
        seq = ValueSequence((1, "a", 2.0))
        self.assertEqual(self._bind("%s", (seq,)), "( 1 , 'a' , 2.0 )")

    def test_generator_param(self):
        # A generator is expected to render like a list literal.
        gen = (n for n in xrange(3))
        self.assertEqual(self._bind("%s", (gen,)), "[ 0 , 1 , 2 ]")

    def test_none_param(self):
        self.assertEqual(self._bind("%s", (None,)), "NULL")

    def test_list_collection(self):
        letters = ['a', 'b', 'c']
        self.assertEqual(self._bind("%s", (letters,)), "[ 'a' , 'b' , 'c' ]")

    def test_set_collection(self):
        # Set iteration order is unspecified, so accept either rendering.
        rendered = self._bind("%s", (set(['a', 'b']),))
        self.assertIn(rendered, ("{ 'a' , 'b' }", "{ 'b' , 'a' }"))

    def test_map_collection(self):
        # Insertion order is preserved so the expected text is deterministic.
        mapping = OrderedDict()
        for key in ('a', 'b', 'c'):
            mapping[key] = key
        self.assertEqual(self._bind("%s", (mapping,)),
                         "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")

    def test_quote_escaping(self):
        # Embedded single quotes are doubled; double quotes pass through.
        raw = """'ef''ef"ef""ef'"""
        self.assertEqual(self._bind("%s", (raw,)),
                         """'''ef''''ef"ef""ef'''""")
class BoundStatementTestCase(unittest.TestCase):
    """Tests for BoundStatement objects created from a PreparedStatement."""

    def _make_prepared_statement(self, **extra):
        """Return a PreparedStatement with two Int32Type columns, foo1 and foo2.

        Any keyword arguments in ``extra`` (e.g. ``fetch_size``) are passed
        straight through to the PreparedStatement constructor.
        """
        keyspace = 'keyspace1'
        column_family = 'cf1'
        column_metadata = [
            (keyspace, column_family, 'foo1', Int32Type),
            (keyspace, column_family, 'foo2', Int32Type)
        ]
        return PreparedStatement(column_metadata=column_metadata,
                                 query_id=None,
                                 routing_key_indexes=[],
                                 query=None,
                                 keyspace=keyspace,
                                 protocol_version=2,
                                 **extra)

    def test_invalid_argument_type(self):
        prepared_statement = self._make_prepared_statement()
        bound_statement = BoundStatement(prepared_statement=prepared_statement)

        # A str where foo1 expects an int: the TypeError message should name
        # the column, the expected CQL type and the offending Python type.
        with self.assertRaises(TypeError) as cm:
            bound_statement.bind(['nonint', 1])
        message = str(cm.exception)
        self.assertIn('foo1', message)
        self.assertIn('Int32Type', message)
        self.assertIn('str', message)

        # Same check for the second column, passing a list instead of an int.
        with self.assertRaises(TypeError) as cm:
            bound_statement.bind([1, ['1', '2']])
        message = str(cm.exception)
        self.assertIn('foo2', message)
        self.assertIn('Int32Type', message)
        self.assertIn('list', message)

    def test_inherit_fetch_size(self):
        # The BoundStatement should pick up fetch_size from its PreparedStatement.
        prepared_statement = self._make_prepared_statement(fetch_size=1234)
        bound_statement = BoundStatement(prepared_statement=prepared_statement)
        self.assertEqual(1234, bound_statement.fetch_size)