Handle custom encoders in nested data types

The root of the problem was that nested data types would use the default
encoders for subitems.  When the encoders were customized, they would
not be used for those nested items.

This fix moves the encoder functions into a class so that collections,
tuples, and UDTs will use the customized mapping when encoding subitems.

Fixes PYTHON-100.
This commit is contained in:
Tyler Hobbs
2014-07-22 14:21:04 -05:00
parent 9121cb1b2f
commit 8bee26d6d6
9 changed files with 213 additions and 224 deletions

View File

@@ -44,7 +44,7 @@ from itertools import groupby
from cassandra import (ConsistencyLevel, AuthenticationFailed,
InvalidRequest, OperationTimedOut, UnsupportedOperation)
from cassandra.connection import ConnectionException, ConnectionShutdown
from cassandra.encoder import cql_encode_all_types, cql_encoders
from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage,
ErrorMessage, ReadTimeoutErrorMessage,
WriteTimeoutErrorMessage,
@@ -1162,25 +1162,25 @@ class Session(object):
.. versionadded:: 2.1.0
"""
encoders = None
encoder = None
"""
A map of python types to CQL encoder functions that will be used when
formatting query parameters for non-prepared statements. This mapping
is not used for prepared statements (because prepared statements
give the driver more information about what CQL types are expected, allowing
it to accept a wider range of python types).
A :class:`~cassandra.encoder.Encoder` instance that will be used when
formatting query parameters for non-prepared statements. This is not used
for prepared statements (because prepared statements give the driver more
information about what CQL types are expected, allowing it to accept a
wider range of python types).
This mapping can be modified by users as they see fit. Functions from
:mod:`cassandra.encoder` should be used, if possible, because they take
precautions to avoid injections and properly sanitize data.
The encoder uses a mapping from python types to encoder methods (for
specific CQL types). This mapping can be modified by users as they see
fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping
values if possible, because they take precautions to avoid injections and
properly sanitize data.
Example::
from cassandra.encoder import cql_encode_tuple
cluster = Cluster()
session = cluster.connect("mykeyspace")
session.encoders[tuple] = cql_encode_tuple
session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple
session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple<int, ascii>)")
session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')])
@@ -1202,7 +1202,7 @@ class Session(object):
self._metrics = cluster.metrics
self._protocol_version = self.cluster.protocol_version
self.encoders = cql_encoders.copy()
self.encoder = Encoder()
# create connection pools in parallel
futures = []
@@ -1328,7 +1328,7 @@ class Session(object):
if six.PY2 and isinstance(query_string, six.text_type):
query_string = query_string.encode('utf-8')
if parameters:
query_string = bind_params(query_string, parameters, self.encoders)
query_string = bind_params(query_string, parameters, self.encoder)
message = QueryMessage(
query_string, cl, query.serial_consistency_level,
fetch_size, timestamp=timestamp)
@@ -1585,13 +1585,13 @@ class Session(object):
raise UserTypeDoesNotExist(
'User type %s does not exist in keyspace %s' % (user_type, keyspace))
def encode(val):
def encode(encoder_self, val):
return '{ %s }' % ' , '.join('%s : %s' % (
field_name,
cql_encode_all_types(getattr(val, field_name))
encoder_self.cql_encode_all_types(getattr(val, field_name))
) for field_name in type_meta.field_names)
self.encoders[klass] = encode
self.encoder.mapping[klass] = encode
def submit(self, fn, *args, **kwargs):
""" Internal """

View File

@@ -49,154 +49,156 @@ def cql_quote(term):
return str(term)
def cql_encode_none(val):
class ValueSequence(list):
pass
class Encoder(object):
"""
Converts :const:`None` to the string 'NULL'.
A container for mapping python types to CQL string literals when working
with non-prepared statements. The type :attr:`~.Encoder.mapping` can be
directly customized by users.
"""
return 'NULL'
def cql_encode_unicode(val):
mapping = None
"""
Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping.
A map of python types to encoder functions.
"""
return cql_quote(val.encode('utf-8'))
def __init__(self):
self.mapping = {
float: self.cql_encode_object,
bytearray: self.cql_encode_bytes,
str: self.cql_encode_str,
int: self.cql_encode_object,
UUID: self.cql_encode_object,
datetime.datetime: self.cql_encode_datetime,
datetime.date: self.cql_encode_date,
dict: self.cql_encode_map_collection,
OrderedDict: self.cql_encode_map_collection,
list: self.cql_encode_list_collection,
tuple: self.cql_encode_list_collection,
set: self.cql_encode_set_collection,
frozenset: self.cql_encode_set_collection,
types.GeneratorType: self.cql_encode_list_collection,
ValueSequence: self.cql_encode_sequence
}
def cql_encode_str(val):
"""
Escapes quotes in :class:`str` objects.
"""
return cql_quote(val)
if six.PY2:
self.mapping.update({
unicode: self.cql_encode_unicode,
buffer: self.cql_encode_bytes,
long: self.cql_encode_object,
types.NoneType: self.cql_encode_none,
})
else:
self.mapping.update({
memoryview: self.cql_encode_bytes,
bytes: self.cql_encode_bytes,
type(None): self.cql_encode_none,
})
# sortedset is optional
try:
from blist import sortedset
self.mapping.update({
sortedset: self.cql_encode_set_collection
})
except ImportError:
pass
if six.PY3:
def cql_encode_bytes(val):
return (b'0x' + hexlify(val)).decode('utf-8')
elif sys.version_info >= (2, 7):
def cql_encode_bytes(val): # noqa
return b'0x' + hexlify(val)
else:
# python 2.6 requires string or read-only buffer for hexlify
def cql_encode_bytes(val): # noqa
return b'0x' + hexlify(buffer(val))
def cql_encode_none(self, val):
"""
Converts :const:`None` to the string 'NULL'.
"""
return 'NULL'
def cql_encode_unicode(self, val):
"""
Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping.
"""
return cql_quote(val.encode('utf-8'))
def cql_encode_object(val):
"""
Default encoder for all objects that do not have a specific encoder function
registered. This function simply calls :meth:`str()` on the object.
"""
return str(val)
def cql_encode_str(self, val):
"""
Escapes quotes in :class:`str` objects.
"""
return cql_quote(val)
if six.PY3:
def cql_encode_bytes(self, val):
return (b'0x' + hexlify(val)).decode('utf-8')
elif sys.version_info >= (2, 7):
def cql_encode_bytes(self, val): # noqa
return b'0x' + hexlify(val)
else:
# python 2.6 requires string or read-only buffer for hexlify
def cql_encode_bytes(self, val): # noqa
return b'0x' + hexlify(buffer(val))
def cql_encode_datetime(val):
"""
Converts a :class:`datetime.datetime` object to a (string) integer timestamp
with millisecond precision.
"""
timestamp = calendar.timegm(val.utctimetuple())
return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
def cql_encode_object(self, val):
"""
Default encoder for all objects that do not have a specific encoder function
registered. This function simply calls :meth:`str()` on the object.
"""
return str(val)
def cql_encode_datetime(self, val):
"""
Converts a :class:`datetime.datetime` object to a (string) integer timestamp
with millisecond precision.
"""
timestamp = calendar.timegm(val.utctimetuple())
return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
def cql_encode_date(val):
"""
Converts a :class:`datetime.date` object to a string with format
``YYYY-MM-DD-0000``.
"""
return "'%s'" % val.strftime('%Y-%m-%d-0000')
def cql_encode_date(self, val):
"""
Converts a :class:`datetime.date` object to a string with format
``YYYY-MM-DD-0000``.
"""
return "'%s'" % val.strftime('%Y-%m-%d-0000')
def cql_encode_sequence(self, val):
"""
Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``IN`` value lists.
"""
return '( %s )' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v)
for v in val)
def cql_encode_sequence(val):
cql_encode_tuple = cql_encode_sequence
"""
Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``IN`` value lists.
is suitable for ``tuple`` type columns.
"""
return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v)
for v in val)
def cql_encode_map_collection(self, val):
"""
Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``.
This is suitable for ``map`` type columns.
"""
return '{ %s }' % ' , '.join('%s : %s' % (
self.mapping.get(type(k), self.cql_encode_object)(k),
self.mapping.get(type(v), self.cql_encode_object)(v)
) for k, v in six.iteritems(val))
cql_encode_tuple = cql_encode_sequence
"""
Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``tuple`` type columns.
"""
def cql_encode_list_collection(self, val):
"""
Converts a sequence to a string of the form ``[item1, item2, ...]``. This
is suitable for ``list`` type columns.
"""
return '[ %s ]' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val)
def cql_encode_set_collection(self, val):
"""
Converts a sequence to a string of the form ``{item1, item2, ...}``. This
is suitable for ``set`` type columns.
"""
return '{ %s }' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val)
def cql_encode_map_collection(val):
"""
Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``.
This is suitable for ``map`` type columns.
"""
return '{ %s }' % ' , '.join('%s : %s' % (
cql_encode_all_types(k),
cql_encode_all_types(v)
) for k, v in six.iteritems(val))
def cql_encode_list_collection(val):
"""
Converts a sequence to a string of the form ``[item1, item2, ...]``. This
is suitable for ``list`` type columns.
"""
return '[ %s ]' % ' , '.join(map(cql_encode_all_types, val))
def cql_encode_set_collection(val):
"""
Converts a sequence to a string of the form ``{item1, item2, ...}``. This
is suitable for ``set`` type columns.
"""
return '{ %s }' % ' , '.join(map(cql_encode_all_types, val))
def cql_encode_all_types(val):
"""
Converts any type into a CQL string, defaulting to ``cql_encode_object``
if :attr:`~.cql_encoders` does not contain an entry for the type.
"""
return cql_encoders.get(type(val), cql_encode_object)(val)
cql_encoders = {
float: cql_encode_object,
bytearray: cql_encode_bytes,
str: cql_encode_str,
int: cql_encode_object,
UUID: cql_encode_object,
datetime.datetime: cql_encode_datetime,
datetime.date: cql_encode_date,
dict: cql_encode_map_collection,
OrderedDict: cql_encode_map_collection,
list: cql_encode_list_collection,
tuple: cql_encode_list_collection,
set: cql_encode_set_collection,
frozenset: cql_encode_set_collection,
types.GeneratorType: cql_encode_list_collection
}
"""
A map of python types to encoder functions.
"""
if six.PY2:
cql_encoders.update({
unicode: cql_encode_unicode,
buffer: cql_encode_bytes,
long: cql_encode_object,
types.NoneType: cql_encode_none,
})
else:
cql_encoders.update({
memoryview: cql_encode_bytes,
bytes: cql_encode_bytes,
type(None): cql_encode_none,
})
# sortedset is optional
try:
from blist import sortedset
cql_encoders.update({
sortedset: cql_encode_set_collection
})
except ImportError:
pass
def cql_encode_all_types(self, val):
"""
Converts any type into a CQL string, defaulting to ``cql_encode_object``
if :attr:`~Encoder.mapping` does not contain an entry for the type.
"""
return self.mapping.get(type(val), self.cql_encode_object)(val)

View File

@@ -27,8 +27,8 @@ import six
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cqltypes import unix_time_from_uuid1
from cassandra.encoder import (cql_encoders, cql_encode_object,
cql_encode_sequence)
from cassandra.encoder import Encoder
import cassandra.encoder
from cassandra.util import OrderedDict
import logging
@@ -625,8 +625,8 @@ class BatchStatement(Statement):
"""
if isinstance(statement, six.string_types):
if parameters:
encoders = cql_encoders if self._session is None else self._session.encoders
statement = bind_params(statement, parameters, encoders)
encoder = Encoder() if self._session is None else self._session.encoder
statement = bind_params(statement, parameters, encoder)
self._statements_and_parameters.append((False, statement, ()))
elif isinstance(statement, PreparedStatement):
query_id = statement.query_id
@@ -644,8 +644,8 @@ class BatchStatement(Statement):
# it must be a SimpleStatement
query_string = statement.query_string
if parameters:
encoders = cql_encoders if self._session is None else self._session.encoders
query_string = bind_params(query_string, parameters, encoders)
encoder = Encoder() if self._session is None else self._session.encoder
query_string = bind_params(query_string, parameters, encoder)
self._statements_and_parameters.append((False, query_string, ()))
return self
@@ -665,33 +665,27 @@ class BatchStatement(Statement):
__repr__ = __str__
class ValueSequence(object):
"""
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
ValueSequence = cassandra.encoder.ValueSequence
"""
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
This is typically needed when supplying a list of keys to select.
For example::
This is typically needed when supplying a list of keys to select.
For example::
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
"""
def __init__(self, sequence):
self.sequence = sequence
def __str__(self):
return cql_encode_sequence(self.sequence)
"""
def bind_params(query, params, encoders):
def bind_params(query, params, encoder):
if isinstance(params, dict):
return query % dict((k, encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params))
else:
return query % tuple(encoders.get(type(v), cql_encode_object)(v) for v in params)
return query % tuple(encoder.cql_encode_all_types(v) for v in params)
class TraceUnavailable(Exception):

View File

@@ -67,7 +67,7 @@
.. autoattribute:: use_client_timestamp
.. autoattribute:: encoders
.. autoattribute:: encoder
.. automethod:: execute(statement[, parameters][, timeout][, trace])

View File

@@ -3,41 +3,34 @@
.. module:: cassandra.encoder
.. data:: cql_encoders
.. autoclass:: Encoder ()
A map of python types to encoder functions.
.. autoattribute:: cassandra.encoder.Encoder.mapping
.. autofunction:: cql_encode_none ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_none ()
.. autofunction:: cql_encode_object ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_object ()
.. autofunction:: cql_encode_all_types ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_all_types ()
.. autofunction:: cql_encode_sequence ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_sequence ()
String Types
------------
.. automethod:: cassandra.encoder.Encoder.cql_encode_str ()
.. autofunction:: cql_encode_str ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_unicode ()
.. autofunction:: cql_encode_unicode ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_bytes ()
.. autofunction:: cql_encode_bytes ()
Converts strings, buffers, and bytearrays into CQL blob literals.
Date Types
----------
.. automethod:: cassandra.encoder.Encoder.cql_encode_datetime ()
.. autofunction:: cql_encode_datetime ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_date ()
.. autofunction:: cql_encode_date ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_map_collection ()
Collection Types
----------------
.. automethod:: cassandra.encoder.Encoder.cql_encode_list_collection ()
.. autofunction:: cql_encode_map_collection ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_set_collection ()
.. autofunction:: cql_encode_list_collection ()
.. autofunction:: cql_encode_set_collection ()
.. autofunction:: cql_encode_tuple ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_tuple ()

View File

@@ -34,8 +34,18 @@
.. autoattribute:: COUNTER
.. autoclass:: ValueSequence
:members:
.. autoclass:: cassandra.query.ValueSequence
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
This is typically needed when supplying a list of keys to select.
For example::
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
.. autoclass:: QueryTrace ()
:members:

View File

@@ -18,9 +18,8 @@ except ImportError:
import unittest # noqa
from cassandra import ConsistencyLevel
from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
SimpleStatement, BatchStatement, BatchType,
dict_factory)
from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement,
BatchStatement, BatchType, dict_factory)
from cassandra.cluster import Cluster
from cassandra.policies import HostDistance
@@ -45,14 +44,6 @@ class QueryTest(unittest.TestCase):
session.execute(bound)
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
def test_value_sequence(self):
"""
Test the output of ValueSequences()
"""
my_user_ids = ('alice', 'bob', 'charles')
self.assertEqual(str(ValueSequence(my_user_ids)), "( 'alice' , 'bob' , 'charles' )")
def test_trace_prints_okay(self):
"""
Code coverage to ensure trace prints to string without error

View File

@@ -34,7 +34,6 @@ except ImportError:
from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.encoder import cql_encode_tuple
from cassandra.query import dict_factory
from cassandra.util import OrderedDict
@@ -416,7 +415,7 @@ class TypeTests(unittest.TestCase):
s = c.connect()
# use this encoder in order to insert tuples
s.encoders[tuple] = cql_encode_tuple
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_type
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
@@ -468,7 +467,7 @@ class TypeTests(unittest.TestCase):
# set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples
s.row_factory = dict_factory
s.encoders[tuple] = cql_encode_tuple
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_type_varying_lengths
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
@@ -501,7 +500,7 @@ class TypeTests(unittest.TestCase):
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.encoders[tuple] = cql_encode_tuple
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_subtypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
@@ -541,7 +540,7 @@ class TypeTests(unittest.TestCase):
if depth == 0:
return 303
else:
return tuple((self.nested_tuples_creator_helper(depth - 1),))
return (self.nested_tuples_creator_helper(depth - 1), )
def test_nested_tuples(self):
"""
@@ -557,7 +556,7 @@ class TypeTests(unittest.TestCase):
# set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples
s.row_factory = dict_factory
s.encoders[tuple] = cql_encode_tuple
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_nested_tuples
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")

View File

@@ -17,7 +17,7 @@ try:
except ImportError:
import unittest # noqa
from cassandra.encoder import cql_encoders
from cassandra.encoder import default_encoder
from cassandra.query import bind_params, ValueSequence
from cassandra.query import PreparedStatement, BoundStatement
from cassandra.cqltypes import Int32Type
@@ -29,31 +29,31 @@ from six.moves import xrange
class ParamBindingTest(unittest.TestCase):
def test_bind_sequence(self):
result = bind_params("%s %s %s", (1, "a", 2.0), cql_encoders)
result = bind_params("%s %s %s", (1, "a", 2.0), default_encoder)
self.assertEqual(result, "1 'a' 2.0")
def test_bind_map(self):
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), cql_encoders)
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), default_encoder)
self.assertEqual(result, "1 'a' 2.0")
def test_sequence_param(self):
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), cql_encoders)
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), default_encoder)
self.assertEqual(result, "( 1 , 'a' , 2.0 )")
def test_generator_param(self):
result = bind_params("%s", ((i for i in xrange(3)),), cql_encoders)
result = bind_params("%s", ((i for i in xrange(3)),), default_encoder)
self.assertEqual(result, "[ 0 , 1 , 2 ]")
def test_none_param(self):
result = bind_params("%s", (None,), cql_encoders)
result = bind_params("%s", (None,), default_encoder)
self.assertEqual(result, "NULL")
def test_list_collection(self):
result = bind_params("%s", (['a', 'b', 'c'],), cql_encoders)
result = bind_params("%s", (['a', 'b', 'c'],), default_encoder)
self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
def test_set_collection(self):
result = bind_params("%s", (set(['a', 'b']),), cql_encoders)
result = bind_params("%s", (set(['a', 'b']),), default_encoder)
self.assertIn(result, ("{ 'a' , 'b' }", "{ 'b' , 'a' }"))
def test_map_collection(self):
@@ -61,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
vals['a'] = 'a'
vals['b'] = 'b'
vals['c'] = 'c'
result = bind_params("%s", (vals,), cql_encoders)
result = bind_params("%s", (vals,), default_encoder)
self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
def test_quote_escaping(self):
result = bind_params("%s", ("""'ef''ef"ef""ef'""",), cql_encoders)
result = bind_params("%s", ("""'ef''ef"ef""ef'""",), default_encoder)
self.assertEqual(result, """'''ef''''ef"ef""ef'''""")