Handle custom encoders in nested data types

The root of the problem was that nested data types would use the default
encoders for subitems.  When the encoders were customized, they would
not be used for those nested items.

This fix moves the encoder functions into a class so that collections,
tuples, and UDTs will use the customized mapping when encoding subitems.

Fixes PYTHON-100.
This commit is contained in:
Tyler Hobbs
2014-07-22 14:21:04 -05:00
parent 9121cb1b2f
commit 8bee26d6d6
9 changed files with 213 additions and 224 deletions

View File

@@ -44,7 +44,7 @@ from itertools import groupby
from cassandra import (ConsistencyLevel, AuthenticationFailed, from cassandra import (ConsistencyLevel, AuthenticationFailed,
InvalidRequest, OperationTimedOut, UnsupportedOperation) InvalidRequest, OperationTimedOut, UnsupportedOperation)
from cassandra.connection import ConnectionException, ConnectionShutdown from cassandra.connection import ConnectionException, ConnectionShutdown
from cassandra.encoder import cql_encode_all_types, cql_encoders from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage, from cassandra.protocol import (QueryMessage, ResultMessage,
ErrorMessage, ReadTimeoutErrorMessage, ErrorMessage, ReadTimeoutErrorMessage,
WriteTimeoutErrorMessage, WriteTimeoutErrorMessage,
@@ -1162,25 +1162,25 @@ class Session(object):
.. versionadded:: 2.1.0 .. versionadded:: 2.1.0
""" """
encoders = None encoder = None
""" """
A map of python types to CQL encoder functions that will be used when A :class:`~cassandra.encoder.Encoder` instance that will be used when
formatting query parameters for non-prepared statements. This mapping formatting query parameters for non-prepared statements. This is not used
is not used for prepared statements (because prepared statements for prepared statements (because prepared statements give the driver more
give the driver more information about what CQL types are expected, allowing information about what CQL types are expected, allowing it to accept a
it to accept a wider range of python types). wider range of python types).
This mapping can be be modified by users as they see fit. Functions from The encoder uses a mapping from python types to encoder methods (for
:mod:`cassandra.encoder` should be used, if possible, because they take specific CQL types). This mapping can be be modified by users as they see
precautions to avoid injections and properly sanitize data. fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping
values if possible, because they take precautions to avoid injections and
properly sanitize data.
Example:: Example::
from cassandra.encoder import cql_encode_tuple
cluster = Cluster() cluster = Cluster()
session = cluster.connect("mykeyspace") session = cluster.connect("mykeyspace")
session.encoders[tuple] = cql_encode_tuple session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple
session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple<int, ascii>)") session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple<int, ascii>)")
session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')]) session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')])
@@ -1202,7 +1202,7 @@ class Session(object):
self._metrics = cluster.metrics self._metrics = cluster.metrics
self._protocol_version = self.cluster.protocol_version self._protocol_version = self.cluster.protocol_version
self.encoders = cql_encoders.copy() self.encoder = Encoder()
# create connection pools in parallel # create connection pools in parallel
futures = [] futures = []
@@ -1328,7 +1328,7 @@ class Session(object):
if six.PY2 and isinstance(query_string, six.text_type): if six.PY2 and isinstance(query_string, six.text_type):
query_string = query_string.encode('utf-8') query_string = query_string.encode('utf-8')
if parameters: if parameters:
query_string = bind_params(query_string, parameters, self.encoders) query_string = bind_params(query_string, parameters, self.encoder)
message = QueryMessage( message = QueryMessage(
query_string, cl, query.serial_consistency_level, query_string, cl, query.serial_consistency_level,
fetch_size, timestamp=timestamp) fetch_size, timestamp=timestamp)
@@ -1585,13 +1585,13 @@ class Session(object):
raise UserTypeDoesNotExist( raise UserTypeDoesNotExist(
'User type %s does not exist in keyspace %s' % (user_type, keyspace)) 'User type %s does not exist in keyspace %s' % (user_type, keyspace))
def encode(val): def encode(encoder_self, val):
return '{ %s }' % ' , '.join('%s : %s' % ( return '{ %s }' % ' , '.join('%s : %s' % (
field_name, field_name,
cql_encode_all_types(getattr(val, field_name)) encoder_self.cql_encode_all_types(getattr(val, field_name))
) for field_name in type_meta.field_names) ) for field_name in type_meta.field_names)
self.encoders[klass] = encode self.encoder.mapping[klass] = encode
def submit(self, fn, *args, **kwargs): def submit(self, fn, *args, **kwargs):
""" Internal """ """ Internal """

View File

@@ -49,154 +49,156 @@ def cql_quote(term):
return str(term) return str(term)
def cql_encode_none(val): class ValueSequence(list):
pass
class Encoder(object):
""" """
Converts :const:`None` to the string 'NULL'. A container for mapping python types to CQL string literals when working
with non-prepared statements. The type :attr:`~.Encoder.mapping` can be
directly customized by users.
""" """
return 'NULL'
mapping = None
def cql_encode_unicode(val):
""" """
Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping. A map of python types to encoder functions.
""" """
return cql_quote(val.encode('utf-8'))
def __init__(self):
self.mapping = {
float: self.cql_encode_object,
bytearray: self.cql_encode_bytes,
str: self.cql_encode_str,
int: self.cql_encode_object,
UUID: self.cql_encode_object,
datetime.datetime: self.cql_encode_datetime,
datetime.date: self.cql_encode_date,
dict: self.cql_encode_map_collection,
OrderedDict: self.cql_encode_map_collection,
list: self.cql_encode_list_collection,
tuple: self.cql_encode_list_collection,
set: self.cql_encode_set_collection,
frozenset: self.cql_encode_set_collection,
types.GeneratorType: self.cql_encode_list_collection,
ValueSequence: self.cql_encode_sequence
}
def cql_encode_str(val): if six.PY2:
""" self.mapping.update({
Escapes quotes in :class:`str` objects. unicode: self.cql_encode_unicode,
""" buffer: self.cql_encode_bytes,
return cql_quote(val) long: self.cql_encode_object,
types.NoneType: self.cql_encode_none,
})
else:
self.mapping.update({
memoryview: self.cql_encode_bytes,
bytes: self.cql_encode_bytes,
type(None): self.cql_encode_none,
})
# sortedset is optional
try:
from blist import sortedset
self.mapping.update({
sortedset: self.cql_encode_set_collection
})
except ImportError:
pass
if six.PY3: def cql_encode_none(self, val):
def cql_encode_bytes(val): """
return (b'0x' + hexlify(val)).decode('utf-8') Converts :const:`None` to the string 'NULL'.
elif sys.version_info >= (2, 7): """
def cql_encode_bytes(val): # noqa return 'NULL'
return b'0x' + hexlify(val)
else:
# python 2.6 requires string or read-only buffer for hexlify
def cql_encode_bytes(val): # noqa
return b'0x' + hexlify(buffer(val))
def cql_encode_unicode(self, val):
"""
Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping.
"""
return cql_quote(val.encode('utf-8'))
def cql_encode_object(val): def cql_encode_str(self, val):
""" """
Default encoder for all objects that do not have a specific encoder function Escapes quotes in :class:`str` objects.
registered. This function simply calls :meth:`str()` on the object. """
""" return cql_quote(val)
return str(val)
if six.PY3:
def cql_encode_bytes(self, val):
return (b'0x' + hexlify(val)).decode('utf-8')
elif sys.version_info >= (2, 7):
def cql_encode_bytes(self, val): # noqa
return b'0x' + hexlify(val)
else:
# python 2.6 requires string or read-only buffer for hexlify
def cql_encode_bytes(self, val): # noqa
return b'0x' + hexlify(buffer(val))
def cql_encode_datetime(val): def cql_encode_object(self, val):
""" """
Converts a :class:`datetime.datetime` object to a (string) integer timestamp Default encoder for all objects that do not have a specific encoder function
with millisecond precision. registered. This function simply calls :meth:`str()` on the object.
""" """
timestamp = calendar.timegm(val.utctimetuple()) return str(val)
return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
def cql_encode_datetime(self, val):
"""
Converts a :class:`datetime.datetime` object to a (string) integer timestamp
with millisecond precision.
"""
timestamp = calendar.timegm(val.utctimetuple())
return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
def cql_encode_date(val): def cql_encode_date(self, val):
""" """
Converts a :class:`datetime.date` object to a string with format Converts a :class:`datetime.date` object to a string with format
``YYYY-MM-DD-0000``. ``YYYY-MM-DD-0000``.
""" """
return "'%s'" % val.strftime('%Y-%m-%d-0000') return "'%s'" % val.strftime('%Y-%m-%d-0000')
def cql_encode_sequence(self, val):
"""
Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``IN`` value lists.
"""
return '( %s )' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v)
for v in val)
def cql_encode_sequence(val): cql_encode_tuple = cql_encode_sequence
""" """
Converts a sequence to a string of the form ``(item1, item2, ...)``. This Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``IN`` value lists. is suitable for ``tuple`` type columns.
""" """
return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v)
for v in val)
def cql_encode_map_collection(self, val):
"""
Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``.
This is suitable for ``map`` type columns.
"""
return '{ %s }' % ' , '.join('%s : %s' % (
self.mapping.get(type(k), self.cql_encode_object)(k),
self.mapping.get(type(v), self.cql_encode_object)(v)
) for k, v in six.iteritems(val))
cql_encode_tuple = cql_encode_sequence def cql_encode_list_collection(self, val):
""" """
Converts a sequence to a string of the form ``(item1, item2, ...)``. This Converts a sequence to a string of the form ``[item1, item2, ...]``. This
is suitable for ``tuple`` type columns. is suitable for ``list`` type columns.
""" """
return '[ %s ]' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val)
def cql_encode_set_collection(self, val):
"""
Converts a sequence to a string of the form ``{item1, item2, ...}``. This
is suitable for ``set`` type columns.
"""
return '{ %s }' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val)
def cql_encode_map_collection(val): def cql_encode_all_types(self, val):
""" """
Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``. Converts any type into a CQL string, defaulting to ``cql_encode_object``
This is suitable for ``map`` type columns. if :attr:`~Encoder.mapping` does not contain an entry for the type.
""" """
return '{ %s }' % ' , '.join('%s : %s' % ( return self.mapping.get(type(val), self.cql_encode_object)(val)
cql_encode_all_types(k),
cql_encode_all_types(v)
) for k, v in six.iteritems(val))
def cql_encode_list_collection(val):
"""
Converts a sequence to a string of the form ``[item1, item2, ...]``. This
is suitable for ``list`` type columns.
"""
return '[ %s ]' % ' , '.join(map(cql_encode_all_types, val))
def cql_encode_set_collection(val):
"""
Converts a sequence to a string of the form ``{item1, item2, ...}``. This
is suitable for ``set`` type columns.
"""
return '{ %s }' % ' , '.join(map(cql_encode_all_types, val))
def cql_encode_all_types(val):
"""
Converts any type into a CQL string, defaulting to ``cql_encode_object``
if :attr:`~.cql_encoders` does not contain an entry for the type.
"""
return cql_encoders.get(type(val), cql_encode_object)(val)
cql_encoders = {
float: cql_encode_object,
bytearray: cql_encode_bytes,
str: cql_encode_str,
int: cql_encode_object,
UUID: cql_encode_object,
datetime.datetime: cql_encode_datetime,
datetime.date: cql_encode_date,
dict: cql_encode_map_collection,
OrderedDict: cql_encode_map_collection,
list: cql_encode_list_collection,
tuple: cql_encode_list_collection,
set: cql_encode_set_collection,
frozenset: cql_encode_set_collection,
types.GeneratorType: cql_encode_list_collection
}
"""
A map of python types to encoder functions.
"""
if six.PY2:
cql_encoders.update({
unicode: cql_encode_unicode,
buffer: cql_encode_bytes,
long: cql_encode_object,
types.NoneType: cql_encode_none,
})
else:
cql_encoders.update({
memoryview: cql_encode_bytes,
bytes: cql_encode_bytes,
type(None): cql_encode_none,
})
# sortedset is optional
try:
from blist import sortedset
cql_encoders.update({
sortedset: cql_encode_set_collection
})
except ImportError:
pass

View File

@@ -27,8 +27,8 @@ import six
from cassandra import ConsistencyLevel, OperationTimedOut from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cqltypes import unix_time_from_uuid1 from cassandra.cqltypes import unix_time_from_uuid1
from cassandra.encoder import (cql_encoders, cql_encode_object, from cassandra.encoder import Encoder
cql_encode_sequence) import cassandra.encoder
from cassandra.util import OrderedDict from cassandra.util import OrderedDict
import logging import logging
@@ -625,8 +625,8 @@ class BatchStatement(Statement):
""" """
if isinstance(statement, six.string_types): if isinstance(statement, six.string_types):
if parameters: if parameters:
encoders = cql_encoders if self._session is None else self._session.encoders encoder = Encoder() if self._session is None else self._session.encoder
statement = bind_params(statement, parameters, encoders) statement = bind_params(statement, parameters, encoder)
self._statements_and_parameters.append((False, statement, ())) self._statements_and_parameters.append((False, statement, ()))
elif isinstance(statement, PreparedStatement): elif isinstance(statement, PreparedStatement):
query_id = statement.query_id query_id = statement.query_id
@@ -644,8 +644,8 @@ class BatchStatement(Statement):
# it must be a SimpleStatement # it must be a SimpleStatement
query_string = statement.query_string query_string = statement.query_string
if parameters: if parameters:
encoders = cql_encoders if self._session is None else self._session.encoders encoder = Encoder() if self._session is None else self._session.encoder
query_string = bind_params(query_string, parameters, encoders) query_string = bind_params(query_string, parameters, encoder)
self._statements_and_parameters.append((False, query_string, ())) self._statements_and_parameters.append((False, query_string, ()))
return self return self
@@ -665,33 +665,27 @@ class BatchStatement(Statement):
__repr__ = __str__ __repr__ = __str__
class ValueSequence(object): ValueSequence = cassandra.encoder.ValueSequence
""" """
A wrapper class that is used to specify that a sequence of values should A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`. as part of the `parameters` argument for :meth:`.Session.execute()`.
This is typically needed when supplying a list of keys to select. This is typically needed when supplying a list of keys to select.
For example:: For example::
>>> my_user_ids = ('alice', 'bob', 'charles') >>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s" >>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)]) >>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
""" """
def __init__(self, sequence):
self.sequence = sequence
def __str__(self):
return cql_encode_sequence(self.sequence)
def bind_params(query, params, encoders): def bind_params(query, params, encoder):
if isinstance(params, dict): if isinstance(params, dict):
return query % dict((k, encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params)) return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params))
else: else:
return query % tuple(encoders.get(type(v), cql_encode_object)(v) for v in params) return query % tuple(encoder.cql_encode_all_types(v) for v in params)
class TraceUnavailable(Exception): class TraceUnavailable(Exception):

View File

@@ -67,7 +67,7 @@
.. autoattribute:: use_client_timestamp .. autoattribute:: use_client_timestamp
.. autoattribute:: encoders .. autoattribute:: encoder
.. automethod:: execute(statement[, parameters][, timeout][, trace]) .. automethod:: execute(statement[, parameters][, timeout][, trace])

View File

@@ -3,41 +3,34 @@
.. module:: cassandra.encoder .. module:: cassandra.encoder
.. data:: cql_encoders .. autoclass:: Encoder ()
A map of python types to encoder functions. .. autoattribute:: cassandra.encoder.Encoder.mapping
.. autofunction:: cql_encode_none () .. automethod:: cassandra.encoder.Encoder.cql_encode_none ()
.. autofunction:: cql_encode_object () .. automethod:: cassandra.encoder.Encoder.cql_encode_object ()
.. autofunction:: cql_encode_all_types () .. automethod:: cassandra.encoder.Encoder.cql_encode_all_types ()
.. autofunction:: cql_encode_sequence () .. automethod:: cassandra.encoder.Encoder.cql_encode_sequence ()
String Types .. automethod:: cassandra.encoder.Encoder.cql_encode_str ()
------------
.. autofunction:: cql_encode_str () .. automethod:: cassandra.encoder.Encoder.cql_encode_unicode ()
.. autofunction:: cql_encode_unicode () .. automethod:: cassandra.encoder.Encoder.cql_encode_bytes ()
.. autofunction:: cql_encode_bytes () Converts strings, buffers, and bytearrays into CQL blob literals.
Date Types .. automethod:: cassandra.encoder.Encoder.cql_encode_datetime ()
----------
.. autofunction:: cql_encode_datetime () .. automethod:: cassandra.encoder.Encoder.cql_encode_date ()
.. autofunction:: cql_encode_date () .. automethod:: cassandra.encoder.Encoder.cql_encode_map_collection ()
Collection Types .. automethod:: cassandra.encoder.Encoder.cql_encode_list_collection ()
----------------
.. autofunction:: cql_encode_map_collection () .. automethod:: cassandra.encoder.Encoder.cql_encode_set_collection ()
.. autofunction:: cql_encode_list_collection () .. automethod:: cql_encode_tuple ()
.. autofunction:: cql_encode_set_collection ()
.. autofunction:: cql_encode_tuple ()

View File

@@ -34,8 +34,18 @@
.. autoattribute:: COUNTER .. autoattribute:: COUNTER
.. autoclass:: ValueSequence .. autoclass:: cassandra.query.ValueSequence
:members:
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
This is typically needed when supplying a list of keys to select.
For example::
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
.. autoclass:: QueryTrace () .. autoclass:: QueryTrace ()
:members: :members:

View File

@@ -18,9 +18,8 @@ except ImportError:
import unittest # noqa import unittest # noqa
from cassandra import ConsistencyLevel from cassandra import ConsistencyLevel
from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence, from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement,
SimpleStatement, BatchStatement, BatchType, BatchStatement, BatchType, dict_factory)
dict_factory)
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.policies import HostDistance from cassandra.policies import HostDistance
@@ -45,14 +44,6 @@ class QueryTest(unittest.TestCase):
session.execute(bound) session.execute(bound)
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01') self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
def test_value_sequence(self):
"""
Test the output of ValueSequences()
"""
my_user_ids = ('alice', 'bob', 'charles')
self.assertEqual(str(ValueSequence(my_user_ids)), "( 'alice' , 'bob' , 'charles' )")
def test_trace_prints_okay(self): def test_trace_prints_okay(self):
""" """
Code coverage to ensure trace prints to string without error Code coverage to ensure trace prints to string without error

View File

@@ -34,7 +34,6 @@ except ImportError:
from cassandra import InvalidRequest from cassandra import InvalidRequest
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type, EMPTY from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.encoder import cql_encode_tuple
from cassandra.query import dict_factory from cassandra.query import dict_factory
from cassandra.util import OrderedDict from cassandra.util import OrderedDict
@@ -416,7 +415,7 @@ class TypeTests(unittest.TestCase):
s = c.connect() s = c.connect()
# use this encoder in order to insert tuples # use this encoder in order to insert tuples
s.encoders[tuple] = cql_encode_tuple s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_type s.execute("""CREATE KEYSPACE test_tuple_type
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""") WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
@@ -468,7 +467,7 @@ class TypeTests(unittest.TestCase):
# set the row_factory to dict_factory for programmatic access # set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples # set the encoder for tuples for the ability to write tuples
s.row_factory = dict_factory s.row_factory = dict_factory
s.encoders[tuple] = cql_encode_tuple s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_type_varying_lengths s.execute("""CREATE KEYSPACE test_tuple_type_varying_lengths
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""") WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
@@ -501,7 +500,7 @@ class TypeTests(unittest.TestCase):
c = Cluster(protocol_version=PROTOCOL_VERSION) c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect() s = c.connect()
s.encoders[tuple] = cql_encode_tuple s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_subtypes s.execute("""CREATE KEYSPACE test_tuple_subtypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""") WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
@@ -541,7 +540,7 @@ class TypeTests(unittest.TestCase):
if depth == 0: if depth == 0:
return 303 return 303
else: else:
return tuple((self.nested_tuples_creator_helper(depth - 1),)) return (self.nested_tuples_creator_helper(depth - 1), )
def test_nested_tuples(self): def test_nested_tuples(self):
""" """
@@ -557,7 +556,7 @@ class TypeTests(unittest.TestCase):
# set the row_factory to dict_factory for programmatic access # set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples # set the encoder for tuples for the ability to write tuples
s.row_factory = dict_factory s.row_factory = dict_factory
s.encoders[tuple] = cql_encode_tuple s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_nested_tuples s.execute("""CREATE KEYSPACE test_nested_tuples
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""") WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")

View File

@@ -17,7 +17,7 @@ try:
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
from cassandra.encoder import cql_encoders from cassandra.encoder import default_encoder
from cassandra.query import bind_params, ValueSequence from cassandra.query import bind_params, ValueSequence
from cassandra.query import PreparedStatement, BoundStatement from cassandra.query import PreparedStatement, BoundStatement
from cassandra.cqltypes import Int32Type from cassandra.cqltypes import Int32Type
@@ -29,31 +29,31 @@ from six.moves import xrange
class ParamBindingTest(unittest.TestCase): class ParamBindingTest(unittest.TestCase):
def test_bind_sequence(self): def test_bind_sequence(self):
result = bind_params("%s %s %s", (1, "a", 2.0), cql_encoders) result = bind_params("%s %s %s", (1, "a", 2.0), default_encoder)
self.assertEqual(result, "1 'a' 2.0") self.assertEqual(result, "1 'a' 2.0")
def test_bind_map(self): def test_bind_map(self):
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), cql_encoders) result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), default_encoder)
self.assertEqual(result, "1 'a' 2.0") self.assertEqual(result, "1 'a' 2.0")
def test_sequence_param(self): def test_sequence_param(self):
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), cql_encoders) result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), default_encoder)
self.assertEqual(result, "( 1 , 'a' , 2.0 )") self.assertEqual(result, "( 1 , 'a' , 2.0 )")
def test_generator_param(self): def test_generator_param(self):
result = bind_params("%s", ((i for i in xrange(3)),), cql_encoders) result = bind_params("%s", ((i for i in xrange(3)),), default_encoder)
self.assertEqual(result, "[ 0 , 1 , 2 ]") self.assertEqual(result, "[ 0 , 1 , 2 ]")
def test_none_param(self): def test_none_param(self):
result = bind_params("%s", (None,), cql_encoders) result = bind_params("%s", (None,), default_encoder)
self.assertEqual(result, "NULL") self.assertEqual(result, "NULL")
def test_list_collection(self): def test_list_collection(self):
result = bind_params("%s", (['a', 'b', 'c'],), cql_encoders) result = bind_params("%s", (['a', 'b', 'c'],), default_encoder)
self.assertEqual(result, "[ 'a' , 'b' , 'c' ]") self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
def test_set_collection(self): def test_set_collection(self):
result = bind_params("%s", (set(['a', 'b']),), cql_encoders) result = bind_params("%s", (set(['a', 'b']),), default_encoder)
self.assertIn(result, ("{ 'a' , 'b' }", "{ 'b' , 'a' }")) self.assertIn(result, ("{ 'a' , 'b' }", "{ 'b' , 'a' }"))
def test_map_collection(self): def test_map_collection(self):
@@ -61,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
vals['a'] = 'a' vals['a'] = 'a'
vals['b'] = 'b' vals['b'] = 'b'
vals['c'] = 'c' vals['c'] = 'c'
result = bind_params("%s", (vals,), cql_encoders) result = bind_params("%s", (vals,), default_encoder)
self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }") self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
def test_quote_escaping(self): def test_quote_escaping(self):
result = bind_params("%s", ("""'ef''ef"ef""ef'""",), cql_encoders) result = bind_params("%s", ("""'ef''ef"ef""ef'""",), default_encoder)
self.assertEqual(result, """'''ef''''ef"ef""ef'''""") self.assertEqual(result, """'''ef''''ef"ef""ef'''""")