Merge branch 'master' into 160-merge

This commit is contained in:
Tyler Hobbs
2014-07-23 15:18:51 -05:00
14 changed files with 515 additions and 227 deletions

View File

@@ -9,6 +9,9 @@ Bug Fixes
* Don't ignore fetch_size arguments to Statement constructors (github-151)
* Allow disabling automatic paging on a per-statement basis when it's
enabled by default for the session (PYTHON-93)
* Raise ValueError when tuple query parameters for prepared statements
have extra items (PYTHON-98)
* Correctly encode nested tuples and UDTs for non-prepared statements (PYTHON-100)
Other
-----

View File

@@ -4,6 +4,10 @@ Releasing
* If dependencies have changed, make sure ``debian/control``
is up to date
* Make sure all patches in ``debian/patches`` still apply cleanly
* Update the debian changelog with the new version::
dch -v '1.0.0'
* Update CHANGELOG.rst
* Update the version in ``cassandra/__init__.py``
* Commit the changelog and version changes

View File

@@ -44,7 +44,7 @@ from itertools import groupby
from cassandra import (ConsistencyLevel, AuthenticationFailed,
InvalidRequest, OperationTimedOut, UnsupportedOperation)
from cassandra.connection import ConnectionException, ConnectionShutdown
from cassandra.encoder import cql_encode_all_types, cql_encoders
from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage,
ErrorMessage, ReadTimeoutErrorMessage,
WriteTimeoutErrorMessage,
@@ -1162,25 +1162,25 @@ class Session(object):
.. versionadded:: 2.1.0
"""
encoders = None
encoder = None
"""
A map of python types to CQL encoder functions that will be used when
formatting query parameters for non-prepared statements. This mapping
is not used for prepared statements (because prepared statements
give the driver more information about what CQL types are expected, allowing
it to accept a wider range of python types).
A :class:`~cassandra.encoder.Encoder` instance that will be used when
formatting query parameters for non-prepared statements. This is not used
for prepared statements (because prepared statements give the driver more
information about what CQL types are expected, allowing it to accept a
wider range of python types).
This mapping can be modified by users as they see fit. Functions from
:mod:`cassandra.encoder` should be used, if possible, because they take
precautions to avoid injections and properly sanitize data.
The encoder uses a mapping from python types to encoder methods (for
specific CQL types). This mapping can be modified by users as they see
fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping
values if possible, because they take precautions to avoid injections and
properly sanitize data.
Example::
from cassandra.encoder import cql_encode_tuple
cluster = Cluster()
session = cluster.connect("mykeyspace")
session.encoders[tuple] = cql_encode_tuple
session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple
session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple<int, ascii>)")
session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')])
@@ -1202,7 +1202,7 @@ class Session(object):
self._metrics = cluster.metrics
self._protocol_version = self.cluster.protocol_version
self.encoders = cql_encoders.copy()
self.encoder = Encoder()
# create connection pools in parallel
futures = []
@@ -1328,7 +1328,7 @@ class Session(object):
if six.PY2 and isinstance(query_string, six.text_type):
query_string = query_string.encode('utf-8')
if parameters:
query_string = bind_params(query_string, parameters, self.encoders)
query_string = bind_params(query_string, parameters, self.encoder)
message = QueryMessage(
query_string, cl, query.serial_consistency_level,
fetch_size, timestamp=timestamp)
@@ -1585,13 +1585,13 @@ class Session(object):
raise UserTypeDoesNotExist(
'User type %s does not exist in keyspace %s' % (user_type, keyspace))
def encode(val):
def encode(encoder_self, val):
return '{ %s }' % ' , '.join('%s : %s' % (
field_name,
cql_encode_all_types(getattr(val, field_name))
encoder_self.cql_encode_all_types(getattr(val, field_name))
) for field_name in type_meta.field_names)
self.encoders[klass] = encode
self.encoder.mapping[klass] = encode
def submit(self, fn, *args, **kwargs):
""" Internal """

View File

@@ -806,6 +806,10 @@ class TupleType(_ParameterizedType):
@classmethod
def serialize_safe(cls, val, protocol_version):
if len(val) > len(cls.subtypes):
raise ValueError("Expected %d items in a tuple, but only got %d: %s" %
(len(cls.subtypes), len(val), val))
proto_version = max(3, protocol_version)
buf = io.BytesIO()
for item, subtype in zip(val, cls.subtypes):

View File

@@ -49,154 +49,156 @@ def cql_quote(term):
return str(term)
def cql_encode_none(val):
class ValueSequence(list):
    """
    A :class:`list` subclass that adds no behavior of its own.

    It exists only so that a value can be told apart, by its type, from an
    ordinary ``list`` when encoders are looked up by ``type(value)``.
    """
class Encoder(object):
"""
Converts :const:`None` to the string 'NULL'.
A container for mapping python types to CQL string literals when working
with non-prepared statements. The type :attr:`~.Encoder.mapping` can be
directly customized by users.
"""
return 'NULL'
def cql_encode_unicode(val):
mapping = None
"""
Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping.
A map of python types to encoder functions.
"""
return cql_quote(val.encode('utf-8'))
def __init__(self):
self.mapping = {
float: self.cql_encode_object,
bytearray: self.cql_encode_bytes,
str: self.cql_encode_str,
int: self.cql_encode_object,
UUID: self.cql_encode_object,
datetime.datetime: self.cql_encode_datetime,
datetime.date: self.cql_encode_date,
dict: self.cql_encode_map_collection,
OrderedDict: self.cql_encode_map_collection,
list: self.cql_encode_list_collection,
tuple: self.cql_encode_list_collection,
set: self.cql_encode_set_collection,
frozenset: self.cql_encode_set_collection,
types.GeneratorType: self.cql_encode_list_collection,
ValueSequence: self.cql_encode_sequence
}
def cql_encode_str(val):
"""
Escapes quotes in :class:`str` objects.
"""
return cql_quote(val)
if six.PY2:
self.mapping.update({
unicode: self.cql_encode_unicode,
buffer: self.cql_encode_bytes,
long: self.cql_encode_object,
types.NoneType: self.cql_encode_none,
})
else:
self.mapping.update({
memoryview: self.cql_encode_bytes,
bytes: self.cql_encode_bytes,
type(None): self.cql_encode_none,
})
# sortedset is optional
try:
from blist import sortedset
self.mapping.update({
sortedset: self.cql_encode_set_collection
})
except ImportError:
pass
if six.PY3:
def cql_encode_bytes(val):
return (b'0x' + hexlify(val)).decode('utf-8')
elif sys.version_info >= (2, 7):
def cql_encode_bytes(val): # noqa
return b'0x' + hexlify(val)
else:
# python 2.6 requires string or read-only buffer for hexlify
def cql_encode_bytes(val): # noqa
return b'0x' + hexlify(buffer(val))
def cql_encode_none(self, val):
    """
    Render a missing value as the CQL keyword ``NULL``.

    The ``val`` argument is ignored; the string ``'NULL'`` is always
    returned.
    """
    null_literal = 'NULL'
    return null_literal
def cql_encode_unicode(self, val):
    """
    Encode a :class:`unicode` object as UTF-8 and hand the encoded bytes to
    ``cql_quote`` for quoting.
    """
    utf8_value = val.encode('utf-8')
    return cql_quote(utf8_value)
def cql_encode_object(val):
"""
Default encoder for all objects that do not have a specific encoder function
registered. This function simply calls :meth:`str()` on the object.
"""
return str(val)
def cql_encode_str(self, val):
    """
    Return the result of passing a :class:`str` object through
    ``cql_quote``, which is responsible for quote escaping.
    """
    quoted = cql_quote(val)
    return quoted
if six.PY3:
    def cql_encode_bytes(self, val):
        """
        Return ``0x`` followed by the ``hexlify`` output for ``val``,
        decoded to text.
        """
        hex_literal = b'0x' + hexlify(val)
        return hex_literal.decode('utf-8')
elif sys.version_info >= (2, 7):
    def cql_encode_bytes(self, val):  # noqa
        """
        Return ``0x`` followed by the ``hexlify`` output for ``val``.
        """
        hex_digits = hexlify(val)
        return b'0x' + hex_digits
else:
    # python 2.6 requires string or read-only buffer for hexlify
    def cql_encode_bytes(self, val):  # noqa
        """
        Return ``0x`` followed by the ``hexlify`` output for ``val``,
        wrapping ``val`` in a ``buffer`` first.
        """
        read_only = buffer(val)
        return b'0x' + hexlify(read_only)
def cql_encode_datetime(val):
"""
Converts a :class:`datetime.datetime` object to a (string) integer timestamp
with millisecond precision.
"""
timestamp = calendar.timegm(val.utctimetuple())
return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
def cql_encode_object(self, val):
    """
    Fallback encoder used for objects that have no dedicated encoder
    registered. The result is simply ``str(val)``.
    """
    as_text = str(val)
    return as_text
def cql_encode_datetime(self, val):
    """
    Convert a :class:`datetime.datetime` object to a (string) integer
    timestamp with millisecond precision.

    The whole-second part comes from ``calendar.timegm`` applied to
    ``val.utctimetuple()``; the sub-second part comes from
    ``val.microsecond`` (taken as 0 when the attribute is absent).
    """
    whole_seconds = calendar.timegm(val.utctimetuple())
    micros = getattr(val, 'microsecond', 0)
    millis = long(whole_seconds * 1e3 + micros / 1e3)
    return str(millis)
def cql_encode_date(val):
"""
Converts a :class:`datetime.date` object to a string with format
``YYYY-MM-DD-0000``.
"""
return "'%s'" % val.strftime('%Y-%m-%d-0000')
def cql_encode_date(self, val):
    """
    Convert a :class:`datetime.date` object to a single-quoted string with
    format ``YYYY-MM-DD-0000``.
    """
    formatted = val.strftime('%Y-%m-%d-0000')
    return "'%s'" % formatted
def cql_encode_sequence(self, val):
    """
    Convert a sequence to a string of the form ``(item1, item2, ...)``.
    This is suitable for ``IN`` value lists.

    Each item is encoded with the encoder registered in ``self.mapping``
    for its exact type, falling back to ``self.cql_encode_object``.
    """
    encoded_items = [
        self.mapping.get(type(item), self.cql_encode_object)(item)
        for item in val
    ]
    return '( %s )' % ' , '.join(encoded_items)
def cql_encode_sequence(val):
cql_encode_tuple = cql_encode_sequence
"""
Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``IN`` value lists.
is suitable for ``tuple`` type columns.
"""
return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v)
for v in val)
def cql_encode_map_collection(self, val):
    """
    Convert a dict into a string of the form
    ``{key1: val1, key2: val2, ...}``. This is suitable for ``map`` type
    columns.

    Keys and values are each encoded with the encoder registered in
    ``self.mapping`` for their exact type, falling back to
    ``self.cql_encode_object``.
    """
    encoded_pairs = []
    for key, value in six.iteritems(val):
        key_cql = self.mapping.get(type(key), self.cql_encode_object)(key)
        value_cql = self.mapping.get(type(value), self.cql_encode_object)(value)
        encoded_pairs.append('%s : %s' % (key_cql, value_cql))
    return '{ %s }' % ' , '.join(encoded_pairs)
cql_encode_tuple = cql_encode_sequence
"""
Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``tuple`` type columns.
"""
def cql_encode_list_collection(self, val):
    """
    Convert a sequence to a string of the form ``[item1, item2, ...]``.
    This is suitable for ``list`` type columns.

    Each item is encoded with the encoder registered in ``self.mapping``
    for its exact type, falling back to ``self.cql_encode_object``.
    """
    encoded_items = [
        self.mapping.get(type(item), self.cql_encode_object)(item)
        for item in val
    ]
    return '[ %s ]' % ' , '.join(encoded_items)
def cql_encode_set_collection(self, val):
    """
    Convert a sequence to a string of the form ``{item1, item2, ...}``.
    This is suitable for ``set`` type columns.

    Each item is encoded with the encoder registered in ``self.mapping``
    for its exact type, falling back to ``self.cql_encode_object``. Items
    appear in the iteration order of ``val``.
    """
    encoded_items = [
        self.mapping.get(type(item), self.cql_encode_object)(item)
        for item in val
    ]
    return '{ %s }' % ' , '.join(encoded_items)
def cql_encode_map_collection(val):
"""
Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``.
This is suitable for ``map`` type columns.
"""
return '{ %s }' % ' , '.join('%s : %s' % (
cql_encode_all_types(k),
cql_encode_all_types(v)
) for k, v in six.iteritems(val))
def cql_encode_list_collection(val):
"""
Converts a sequence to a string of the form ``[item1, item2, ...]``. This
is suitable for ``list`` type columns.
"""
return '[ %s ]' % ' , '.join(map(cql_encode_all_types, val))
def cql_encode_set_collection(val):
"""
Converts a sequence to a string of the form ``{item1, item2, ...}``. This
is suitable for ``set`` type columns.
"""
return '{ %s }' % ' , '.join(map(cql_encode_all_types, val))
def cql_encode_all_types(val):
"""
Converts any type into a CQL string, defaulting to ``cql_encode_object``
if :attr:`~.cql_encoders` does not contain an entry for the type.
"""
return cql_encoders.get(type(val), cql_encode_object)(val)
cql_encoders = {
float: cql_encode_object,
bytearray: cql_encode_bytes,
str: cql_encode_str,
int: cql_encode_object,
UUID: cql_encode_object,
datetime.datetime: cql_encode_datetime,
datetime.date: cql_encode_date,
dict: cql_encode_map_collection,
OrderedDict: cql_encode_map_collection,
list: cql_encode_list_collection,
tuple: cql_encode_list_collection,
set: cql_encode_set_collection,
frozenset: cql_encode_set_collection,
types.GeneratorType: cql_encode_list_collection
}
"""
A map of python types to encoder functions.
"""
if six.PY2:
cql_encoders.update({
unicode: cql_encode_unicode,
buffer: cql_encode_bytes,
long: cql_encode_object,
types.NoneType: cql_encode_none,
})
else:
cql_encoders.update({
memoryview: cql_encode_bytes,
bytes: cql_encode_bytes,
type(None): cql_encode_none,
})
# sortedset is optional
try:
from blist import sortedset
cql_encoders.update({
sortedset: cql_encode_set_collection
})
except ImportError:
pass
def cql_encode_all_types(self, val):
    """
    Convert any value into a CQL string.

    The encoder is looked up in :attr:`~Encoder.mapping` by the exact type
    of ``val``; ``cql_encode_object`` is used when the mapping has no entry
    for that type.
    """
    encoder = self.mapping.get(type(val), self.cql_encode_object)
    return encoder(val)

View File

@@ -27,8 +27,8 @@ import six
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cqltypes import unix_time_from_uuid1
from cassandra.encoder import (cql_encoders, cql_encode_object,
cql_encode_sequence)
from cassandra.encoder import Encoder
import cassandra.encoder
from cassandra.util import OrderedDict
import logging
@@ -57,7 +57,7 @@ def tuple_factory(colnames, rows):
Example::
>>> from cassandra.query import named_tuple_factory
>>> from cassandra.query import tuple_factory
>>> session = cluster.connect('mykeyspace')
>>> session.row_factory = tuple_factory
>>> rows = session.execute("SELECT name, age FROM users LIMIT 1")
@@ -625,8 +625,8 @@ class BatchStatement(Statement):
"""
if isinstance(statement, six.string_types):
if parameters:
encoders = cql_encoders if self._session is None else self._session.encoders
statement = bind_params(statement, parameters, encoders)
encoder = Encoder() if self._session is None else self._session.encoder
statement = bind_params(statement, parameters, encoder)
self._statements_and_parameters.append((False, statement, ()))
elif isinstance(statement, PreparedStatement):
query_id = statement.query_id
@@ -644,8 +644,8 @@ class BatchStatement(Statement):
# it must be a SimpleStatement
query_string = statement.query_string
if parameters:
encoders = cql_encoders if self._session is None else self._session.encoders
query_string = bind_params(query_string, parameters, encoders)
encoder = Encoder() if self._session is None else self._session.encoder
query_string = bind_params(query_string, parameters, encoder)
self._statements_and_parameters.append((False, query_string, ()))
return self
@@ -665,33 +665,27 @@ class BatchStatement(Statement):
__repr__ = __str__
class ValueSequence(object):
"""
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
ValueSequence = cassandra.encoder.ValueSequence
"""
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
This is typically needed when supplying a list of keys to select.
For example::
This is typically needed when supplying a list of keys to select.
For example::
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
"""
def __init__(self, sequence):
self.sequence = sequence
def __str__(self):
return cql_encode_sequence(self.sequence)
"""
def bind_params(query, params, encoders):
def bind_params(query, params, encoder):
if isinstance(params, dict):
return query % dict((k, encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params))
else:
return query % tuple(encoders.get(type(v), cql_encode_object)(v) for v in params)
return query % tuple(encoder.cql_encode_all_types(v) for v in params)
class TraceUnavailable(Exception):

View File

@@ -67,7 +67,7 @@
.. autoattribute:: use_client_timestamp
.. autoattribute:: encoders
.. autoattribute:: encoder
.. automethod:: execute(statement[, parameters][, timeout][, trace])

View File

@@ -3,41 +3,34 @@
.. module:: cassandra.encoder
.. data:: cql_encoders
.. autoclass:: Encoder ()
A map of python types to encoder functions.
.. autoattribute:: cassandra.encoder.Encoder.mapping
.. autofunction:: cql_encode_none ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_none ()
.. autofunction:: cql_encode_object ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_object ()
.. autofunction:: cql_encode_all_types ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_all_types ()
.. autofunction:: cql_encode_sequence ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_sequence ()
String Types
------------
.. automethod:: cassandra.encoder.Encoder.cql_encode_str ()
.. autofunction:: cql_encode_str ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_unicode ()
.. autofunction:: cql_encode_unicode ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_bytes ()
.. autofunction:: cql_encode_bytes ()
Converts strings, buffers, and bytearrays into CQL blob literals.
Date Types
----------
.. automethod:: cassandra.encoder.Encoder.cql_encode_datetime ()
.. autofunction:: cql_encode_datetime ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_date ()
.. autofunction:: cql_encode_date ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_map_collection ()
Collection Types
----------------
.. automethod:: cassandra.encoder.Encoder.cql_encode_list_collection ()
.. autofunction:: cql_encode_map_collection ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_set_collection ()
.. autofunction:: cql_encode_list_collection ()
.. autofunction:: cql_encode_set_collection ()
.. autofunction:: cql_encode_tuple ()
.. automethod:: cql_encode_tuple ()

View File

@@ -34,8 +34,18 @@
.. autoattribute:: COUNTER
.. autoclass:: ValueSequence
:members:
.. autoclass:: cassandra.query.ValueSequence
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
This is typically needed when supplying a list of keys to select.
For example::
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
.. autoclass:: QueryTrace ()
:members:

View File

@@ -16,6 +16,11 @@ import datetime
from uuid import UUID
import pytz
try:
from blist import sortedset
except ImportError:
sortedset = set # noqa
DATA_TYPE_PRIMITIVES = [
'ascii',
'bigint',
@@ -113,3 +118,22 @@ def get_sample(datatype):
"""
return SAMPLE_DATA[datatype]
def get_nonprim_sample(non_prim_type, datatype):
    """
    Build sample data for a non-primitive (container) type.

    ``non_prim_type`` selects the container (``'list'``, ``'set'``,
    ``'map'`` or ``'tuple'``) and ``datatype`` names the element type whose
    sample is fetched with ``get_sample``. Any other container name raises
    an ``Exception``.
    """
    if non_prim_type == 'list':
        return [get_sample(datatype), get_sample(datatype)]

    if non_prim_type == 'set':
        return sortedset([get_sample(datatype)])

    if non_prim_type == 'map':
        # a blob value is still used, but the key falls back to an ascii sample
        key_type = 'ascii' if datatype == 'blob' else datatype
        return {get_sample(key_type): get_sample(datatype)}

    if non_prim_type == 'tuple':
        return (get_sample(datatype),)

    raise Exception('Missing handling of non-primitive type {0}.'.format(non_prim_type))

View File

@@ -18,9 +18,8 @@ except ImportError:
import unittest # noqa
from cassandra import ConsistencyLevel
from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
SimpleStatement, BatchStatement, BatchType,
dict_factory)
from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement,
BatchStatement, BatchType, dict_factory)
from cassandra.cluster import Cluster
from cassandra.policies import HostDistance
@@ -45,14 +44,6 @@ class QueryTest(unittest.TestCase):
session.execute(bound)
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
def test_value_sequence(self):
"""
Test the output of ValueSequences()
"""
my_user_ids = ('alice', 'bob', 'charles')
self.assertEqual(str(ValueSequence(my_user_ids)), "( 'alice' , 'bob' , 'charles' )")
def test_trace_prints_okay(self):
"""
Code coverage to ensure trace prints to string without error

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.integration.datatype_utils import get_sample, DATA_TYPE_PRIMITIVES
from tests.integration.datatype_utils import get_sample, DATA_TYPE_PRIMITIVES, DATA_TYPE_NON_PRIMITIVE_NAMES
try:
import unittest2 as unittest
@@ -34,7 +34,6 @@ except ImportError:
from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.encoder import cql_encode_tuple
from cassandra.query import dict_factory
from cassandra.util import OrderedDict
@@ -416,7 +415,7 @@ class TypeTests(unittest.TestCase):
s = c.connect()
# use this encoder in order to insert tuples
s.encoders[tuple] = cql_encode_tuple
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_type
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
@@ -448,6 +447,9 @@ class TypeTests(unittest.TestCase):
s.execute(prepared, parameters=(4, partial))
s.execute(prepared, parameters=(5, subpartial))
# extra items in the tuple should result in an error
self.assertRaises(ValueError, s.execute, prepared, parameters=(0, (1, 2, 3, 4, 5, 6)))
prepared = s.prepare("SELECT b FROM mytable WHERE a=?")
self.assertEqual(complete, s.execute(prepared, (3,))[0].b)
self.assertEqual(partial_result, s.execute(prepared, (4,))[0].b)
@@ -468,7 +470,7 @@ class TypeTests(unittest.TestCase):
# set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples
s.row_factory = dict_factory
s.encoders[tuple] = cql_encode_tuple
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_type_varying_lengths
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
@@ -491,7 +493,7 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
self.assertEqual(tuple(created_tuple), result['v_%s' % i])
def test_tuple_subtypes(self):
def test_tuple_primitive_subtypes(self):
"""
Ensure tuple subtypes are appropriately handled.
"""
@@ -501,11 +503,11 @@ class TypeTests(unittest.TestCase):
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.encoders[tuple] = cql_encode_tuple
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_types
s.execute("""CREATE KEYSPACE test_tuple_primitive_subtypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_tuple_types")
s.set_keyspace("test_tuple_primitive_subtypes")
s.execute("CREATE TABLE mytable ("
"k int PRIMARY KEY, "
@@ -523,6 +525,151 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT v FROM mytable WHERE k=%s", (i,))[0]
self.assertEqual(response_tuple, result.v)
def nested_tuples_schema_helper(self, depth):
    """
    Return the CQL type string for an ``int`` wrapped in ``depth`` levels
    of ``tuple<...>``; a depth of 0 yields plain ``int``.
    """
    if depth != 0:
        inner_type = self.nested_tuples_schema_helper(depth - 1)
        return 'tuple<%s>' % inner_type

    return 'int'
def nested_tuples_creator_helper(self, depth):
    """
    Return the value 303 wrapped in ``depth`` levels of one-element tuples;
    a depth of 0 yields the bare integer.
    """
    if depth != 0:
        inner_value = self.nested_tuples_creator_helper(depth - 1)
        return (inner_value, )

    return 303
def test_tuple_non_primitive_subtypes(self):
    """
    Ensure tuple subtypes are appropriately handled for maps, sets, and lists.

    One ``tuple<list<T>>``, ``tuple<set<T>>`` and ``tuple<map<T, T>>``
    column is created per entry of ``DATA_TYPE_PRIMITIVES``; each column is
    written with a non-prepared INSERT and then read back and compared.
    """

    if self._cass_version < (2, 1, 0):
        raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")

    c = Cluster(protocol_version=PROTOCOL_VERSION)
    s = c.connect()

    # set the row_factory to dict_factory for programmatic access
    # set the encoder for tuples for the ability to write tuples
    s.row_factory = dict_factory
    s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple

    s.execute("""CREATE KEYSPACE test_tuple_non_primitive_subtypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
    s.set_keyspace("test_tuple_non_primitive_subtypes")

    # column definitions; explicit field indexes ({0}, {1}, ...) are used
    # because auto-numbered '{}' fields are not supported before Python 2.7
    values = []

    # create list values
    for datatype in DATA_TYPE_PRIMITIVES:
        values.append('v_{0} tuple<list<{1}>>'.format(len(values), datatype))

    # create set values
    for datatype in DATA_TYPE_PRIMITIVES:
        values.append('v_{0} tuple<set<{1}>>'.format(len(values), datatype))

    # create map values
    for datatype in DATA_TYPE_PRIMITIVES:
        datatype_1 = datatype_2 = datatype
        if datatype == 'blob':
            # unhashable type: 'bytearray'
            datatype_1 = 'ascii'
        values.append('v_{0} tuple<map<{1}, {2}>>'.format(len(values), datatype_1, datatype_2))

    # make sure we're testing all non primitive data types in the future
    handled_types = set(['tuple', 'list', 'map', 'set'])
    if set(DATA_TYPE_NON_PRIMITIVE_NAMES) != handled_types:
        # NotImplementedError is the exception class; the NotImplemented
        # constant is not callable and would raise a TypeError here instead
        raise NotImplementedError('Missing datatype not implemented: {0}'.format(
            set(DATA_TYPE_NON_PRIMITIVE_NAMES) - handled_types
        ))

    # create table
    s.execute("CREATE TABLE mytable ("
              "k int PRIMARY KEY, "
              "%s)" % ', '.join(values))

    # column index; advances in the same order the columns were declared
    i = 0
    # test tuple<list<datatype>>
    for datatype in DATA_TYPE_PRIMITIVES:
        created_tuple = tuple([[get_sample(datatype)]])
        s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))

        result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
        self.assertEqual(created_tuple, result['v_%s' % i])
        i += 1

    # test tuple<set<datatype>>
    for datatype in DATA_TYPE_PRIMITIVES:
        created_tuple = tuple([sortedset([get_sample(datatype)])])
        s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))

        result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
        self.assertEqual(created_tuple, result['v_%s' % i])
        i += 1

    # test tuple<map<datatype, datatype>>
    for datatype in DATA_TYPE_PRIMITIVES:
        if datatype == 'blob':
            # unhashable type: 'bytearray'
            created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}])
        else:
            created_tuple = tuple([{get_sample(datatype): get_sample(datatype)}])

        s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))

        result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
        self.assertEqual(created_tuple, result['v_%s' % i])
        i += 1
def test_nested_tuples(self):
    """
    Ensure nested tuples are appropriately handled.

    Columns holding tuples nested 1, 2, 3 and 128 levels deep are written
    with non-prepared INSERTs and then read back and compared.
    """

    if self._cass_version < (2, 1, 0):
        raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")

    cluster = Cluster(protocol_version=PROTOCOL_VERSION)
    session = cluster.connect()

    # dict_factory gives access to result columns by name;
    # the tuple encoder is what allows tuples to be written
    session.row_factory = dict_factory
    session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple

    session.execute("""CREATE KEYSPACE test_nested_tuples
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
    session.set_keyspace("test_nested_tuples")

    # one column per nesting depth under test
    depths = (1, 2, 3, 128)
    column_types = tuple(self.nested_tuples_schema_helper(depth) for depth in depths)
    session.execute("CREATE TABLE mytable ("
                    "k int PRIMARY KEY, "
                    "v_1 %s,"
                    "v_2 %s,"
                    "v_3 %s,"
                    "v_128 %s"
                    ")" % column_types)

    for depth in depths:
        # build the nested value for this depth
        nested_value = self.nested_tuples_creator_helper(depth)

        # write it into the matching column
        session.execute("INSERT INTO mytable (k, v_%s) VALUES (%s, %s)", (depth, depth, nested_value))

        # read it back and compare with what was written
        row = session.execute("SELECT v_%s FROM mytable WHERE k=%s", (depth, depth))[0]
        self.assertEqual(nested_value, row['v_%s' % depth])
def test_unicode_query_string(self):
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()

View File

@@ -26,6 +26,8 @@ from collections import namedtuple
from cassandra.cluster import Cluster, UserTypeDoesNotExist
from tests.integration import get_server_versions, PROTOCOL_VERSION
from tests.integration.datatype_utils import get_sample, get_nonprim_sample,\
DATA_TYPE_PRIMITIVES, DATA_TYPE_NON_PRIMITIVE_NAMES
class TypeTests(unittest.TestCase):
@@ -325,3 +327,117 @@ class TypeTests(unittest.TestCase):
User = namedtuple('user', ('age', 'name'))
self.assertRaises(UserTypeDoesNotExist, c.register_user_type, "some_bad_keyspace", "user", User)
self.assertRaises(UserTypeDoesNotExist, c.register_user_type, "system", "user", User)
def test_primitive_datatypes(self):
    """
    Test for inserting various types of DATA_TYPE_PRIMITIVES into UDT's

    A UDT with one single-letter field per primitive type is created,
    registered against a namedtuple, written through a prepared INSERT and
    then read back field by field.
    """
    c = Cluster(protocol_version=PROTOCOL_VERSION)
    s = c.connect()

    # create keyspace
    s.execute("""
CREATE KEYSPACE test_primitive_datatypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
    s.set_keyspace("test_primitive_datatypes")

    # create UDT
    alpha_type_list = []
    start_index = ord('a')
    for i, datatype in enumerate(DATA_TYPE_PRIMITIVES):
        alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype))

    s.execute("""
CREATE TYPE alldatatypes ({0})
""".format(', '.join(alpha_type_list))
    )

    s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b alldatatypes)")

    # register UDT
    # explicit '{0}' is used because auto-numbered '{}' fields are not
    # supported before Python 2.7
    alphabet_list = []
    for i in range(ord('a'), ord('a') + len(DATA_TYPE_PRIMITIVES)):
        alphabet_list.append('{0}'.format(chr(i)))
    Alldatatypes = namedtuple("alldatatypes", alphabet_list)
    c.register_user_type("test_primitive_datatypes", "alldatatypes", Alldatatypes)

    # insert UDT data
    params = [get_sample(datatype) for datatype in DATA_TYPE_PRIMITIVES]

    insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
    s.execute(insert, (0, Alldatatypes(*params)))

    # retrieve and verify data
    results = s.execute("SELECT * FROM mytable")
    self.assertEqual(1, len(results))

    row = results[0].b
    for expected, actual in zip(params, row):
        self.assertEqual(expected, actual)

    c.shutdown()
def test_nonprimitive_datatypes(self):
    """
    Test for inserting various types of DATA_TYPE_NON_PRIMITIVE into UDT's

    The UDT gets one ``<letter>_<letter>`` field for every combination of a
    non-primitive container type and a primitive element type. It is
    written through a prepared INSERT and read back field by field.
    """
    cluster = Cluster(protocol_version=PROTOCOL_VERSION)
    session = cluster.connect()

    # create keyspace
    session.execute("""
CREATE KEYSPACE test_nonprimitive_datatypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
    session.set_keyspace("test_nonprimitive_datatypes")

    # create UDT
    first_letter = ord('a')
    field_definitions = []
    for outer, container in enumerate(DATA_TYPE_NON_PRIMITIVE_NAMES):
        # maps take the element type for both key and value
        if container == "map":
            template = "{0}_{1} {2}<{3}, {3}>"
        else:
            template = "{0}_{1} {2}<{3}>"
        for inner, primitive in enumerate(DATA_TYPE_PRIMITIVES):
            field_definitions.append(template.format(chr(first_letter + outer), chr(first_letter + inner),
                                                     container, primitive))

    session.execute("""
CREATE TYPE alldatatypes ({0})
""".format(', '.join(field_definitions))
    )

    session.execute("CREATE TABLE mytable (a int PRIMARY KEY, b alldatatypes)")

    # register UDT; field names follow the same letter scheme as above
    field_names = [
        '{0}_{1}'.format(chr(first_letter + outer), chr(first_letter + inner))
        for outer in range(len(DATA_TYPE_NON_PRIMITIVE_NAMES))
        for inner in range(len(DATA_TYPE_PRIMITIVES))
    ]
    Alldatatypes = namedtuple("alldatatypes", field_names)
    cluster.register_user_type("test_nonprimitive_datatypes", "alldatatypes", Alldatatypes)

    # insert UDT data
    samples = [
        get_nonprim_sample(container, primitive)
        for container in DATA_TYPE_NON_PRIMITIVE_NAMES
        for primitive in DATA_TYPE_PRIMITIVES
    ]

    prepared_insert = session.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
    session.execute(prepared_insert, (0, Alldatatypes(*samples)))

    # retrieve and verify data
    rows = session.execute("SELECT * FROM mytable")
    self.assertEqual(1, len(rows))

    stored = rows[0].b
    for expected, actual in zip(samples, stored):
        self.assertEqual(expected, actual)

    cluster.shutdown()

View File

@@ -17,7 +17,7 @@ try:
except ImportError:
import unittest # noqa
from cassandra.encoder import cql_encoders
from cassandra.encoder import Encoder
from cassandra.query import bind_params, ValueSequence
from cassandra.query import PreparedStatement, BoundStatement
from cassandra.cqltypes import Int32Type
@@ -29,31 +29,31 @@ from six.moves import xrange
class ParamBindingTest(unittest.TestCase):
def test_bind_sequence(self):
result = bind_params("%s %s %s", (1, "a", 2.0), cql_encoders)
result = bind_params("%s %s %s", (1, "a", 2.0), Encoder())
self.assertEqual(result, "1 'a' 2.0")
def test_bind_map(self):
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), cql_encoders)
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), Encoder())
self.assertEqual(result, "1 'a' 2.0")
def test_sequence_param(self):
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), cql_encoders)
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), Encoder())
self.assertEqual(result, "( 1 , 'a' , 2.0 )")
def test_generator_param(self):
result = bind_params("%s", ((i for i in xrange(3)),), cql_encoders)
result = bind_params("%s", ((i for i in xrange(3)),), Encoder())
self.assertEqual(result, "[ 0 , 1 , 2 ]")
def test_none_param(self):
result = bind_params("%s", (None,), cql_encoders)
result = bind_params("%s", (None,), Encoder())
self.assertEqual(result, "NULL")
def test_list_collection(self):
result = bind_params("%s", (['a', 'b', 'c'],), cql_encoders)
result = bind_params("%s", (['a', 'b', 'c'],), Encoder())
self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
def test_set_collection(self):
result = bind_params("%s", (set(['a', 'b']),), cql_encoders)
result = bind_params("%s", (set(['a', 'b']),), Encoder())
self.assertIn(result, ("{ 'a' , 'b' }", "{ 'b' , 'a' }"))
def test_map_collection(self):
@@ -61,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
vals['a'] = 'a'
vals['b'] = 'b'
vals['c'] = 'c'
result = bind_params("%s", (vals,), cql_encoders)
result = bind_params("%s", (vals,), Encoder())
self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
def test_quote_escaping(self):
result = bind_params("%s", ("""'ef''ef"ef""ef'""",), cql_encoders)
result = bind_params("%s", ("""'ef''ef"ef""ef'""",), Encoder())
self.assertEqual(result, """'''ef''''ef"ef""ef'''""")