Add support for tuple type

This commit is contained in:
Tyler Hobbs
2014-06-24 14:46:21 -05:00
parent f1d143cdfe
commit 4567f2d610
4 changed files with 78 additions and 24 deletions

View File

@@ -1086,7 +1086,7 @@ class Session(object):
self._metrics = cluster.metrics
self._protocol_version = self.cluster.protocol_version
self._encoders = cql_encoders.copy()
self.encoders = cql_encoders.copy()
# create connection pools in parallel
futures = []
@@ -1210,7 +1210,7 @@ class Session(object):
if isinstance(query, SimpleStatement):
query_string = query.query_string
if parameters:
query_string = bind_params(query.query_string, parameters, self._encoders)
query_string = bind_params(query.query_string, parameters, self.encoders)
message = QueryMessage(
query_string, cl, query.serial_consistency_level,
fetch_size, timestamp=timestamp)
@@ -1460,7 +1460,7 @@ class Session(object):
cql_encode_all_types(getattr(val, field_name))
) for field_name in type_meta.field_names)
self._encoders[klass] = encode
self.encoders[klass] = encode
def submit(self, fn, *args, **kwargs):
""" Internal """

View File

@@ -774,10 +774,45 @@ class MapType(_ParameterizedType):
return buf.getvalue()
class UserDefinedType(_ParameterizedType):
typename = "'org.apache.cassandra.db.marshal.UserType'"
class TupleType(_ParameterizedType):
typename = 'tuple'
num_subtypes = 'UNKNOWN'
FIELD_LENGTH = 4
@classmethod
def deserialize_safe(cls, byts, protocol_version):
proto_version = max(3, protocol_version)
p = 0
values = []
for col_type in cls.subtypes:
if p == len(byts):
break
itemlen = int32_unpack(byts[p:p + 4])
p += 4
item = byts[p:p + itemlen]
p += itemlen
# collections inside UDTs are always encoded with at least the
# version 3 format
values.append(col_type.from_binary(item, proto_version))
if len(values) < len(cls.subtypes):
nones = [None] * (len(cls.subtypes) - len(values))
values = values + nones
return tuple(values)
@classmethod
def serialize_safe(cls, val, protocol_version):
proto_version = max(3, protocol_version)
buf = io.BytesIO()
for item, subtype in zip(val, cls.subtypes):
packed_item = subtype.to_binary(item, proto_version)
buf.write(int32_pack(len(packed_item)))
buf.write(packed_item)
return buf.getvalue()
class UserDefinedType(TupleType):
typename = "'org.apache.cassandra.db.marshal.UserType'"
_cache = {}
@@ -804,24 +839,7 @@ class UserDefinedType(_ParameterizedType):
@classmethod
def deserialize_safe(cls, byts, protocol_version):
proto_version = max(3, protocol_version)
p = 0
values = []
for col_type in cls.subtypes:
if p == len(byts):
break
itemlen = int32_unpack(byts[p:p + cls.FIELD_LENGTH])
p += cls.FIELD_LENGTH
item = byts[p:p + itemlen]
p += itemlen
# collections inside UDTs are always encoded with at least the
# version 3 format
values.append(col_type.from_binary(item, proto_version))
if len(values) < len(cls.subtypes):
nones = [None] * (len(cls.subtypes) - len(values))
values = values + nones
values = TupleType.deserialize_safe(byts, protocol_version)
if cls.mapped_class:
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
else:

View File

@@ -84,6 +84,7 @@ def cql_encode_date(val):
def cql_encode_sequence(val):
return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v)
for v in val)
cql_encode_tuple = cql_encode_sequence
def cql_encode_map_collection(val):

View File

@@ -33,6 +33,7 @@ except ImportError:
from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.encoder import cql_encode_tuple
from cassandra.query import dict_factory
from cassandra.util import OrderedDict
@@ -398,3 +399,37 @@ class TypeTests(unittest.TestCase):
s.execute(prepared, parameters=(dt,))
result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
def test_tuple_type(self):
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.encoders[tuple] = cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_type
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_tuple_type")
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b tuple<ascii, int, boolean>)")
# test non-prepared statement
complete = ('foo', 123, True)
s.execute("INSERT INTO mytable (a, b) VALUES (0, %s)", parameters=(complete,))
result = s.execute("SELECT b FROM mytable WHERE a=0")[0]
self.assertEqual(complete, result.b)
partial = ('bar', 456)
partial_result = partial + (None,)
s.execute("INSERT INTO mytable (a, b) VALUES (1, %s)", parameters=(partial,))
result = s.execute("SELECT b FROM mytable WHERE a=1")[0]
self.assertEqual(partial_result, result.b)
# test prepared statement
prepared = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
s.execute(prepared, parameters=(2, complete))
s.execute(prepared, parameters=(3, partial))
prepared = s.prepare("SELECT b FROM mytable WHERE a=?")
self.assertEqual(complete, s.execute(prepared, (2,))[0].b)
self.assertEqual(partial_result, s.execute(prepared, (3,))[0].b)