Add support for tuple type
This commit is contained in:
@@ -1086,7 +1086,7 @@ class Session(object):
|
||||
self._metrics = cluster.metrics
|
||||
self._protocol_version = self.cluster.protocol_version
|
||||
|
||||
self._encoders = cql_encoders.copy()
|
||||
self.encoders = cql_encoders.copy()
|
||||
|
||||
# create connection pools in parallel
|
||||
futures = []
|
||||
@@ -1210,7 +1210,7 @@ class Session(object):
|
||||
if isinstance(query, SimpleStatement):
|
||||
query_string = query.query_string
|
||||
if parameters:
|
||||
query_string = bind_params(query.query_string, parameters, self._encoders)
|
||||
query_string = bind_params(query.query_string, parameters, self.encoders)
|
||||
message = QueryMessage(
|
||||
query_string, cl, query.serial_consistency_level,
|
||||
fetch_size, timestamp=timestamp)
|
||||
@@ -1460,7 +1460,7 @@ class Session(object):
|
||||
cql_encode_all_types(getattr(val, field_name))
|
||||
) for field_name in type_meta.field_names)
|
||||
|
||||
self._encoders[klass] = encode
|
||||
self.encoders[klass] = encode
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
""" Internal """
|
||||
|
||||
@@ -774,10 +774,45 @@ class MapType(_ParameterizedType):
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
class UserDefinedType(_ParameterizedType):
|
||||
typename = "'org.apache.cassandra.db.marshal.UserType'"
|
||||
class TupleType(_ParameterizedType):
|
||||
typename = 'tuple'
|
||||
num_subtypes = 'UNKNOWN'
|
||||
|
||||
FIELD_LENGTH = 4
|
||||
@classmethod
|
||||
def deserialize_safe(cls, byts, protocol_version):
|
||||
proto_version = max(3, protocol_version)
|
||||
p = 0
|
||||
values = []
|
||||
for col_type in cls.subtypes:
|
||||
if p == len(byts):
|
||||
break
|
||||
itemlen = int32_unpack(byts[p:p + 4])
|
||||
p += 4
|
||||
item = byts[p:p + itemlen]
|
||||
p += itemlen
|
||||
# collections inside UDTs are always encoded with at least the
|
||||
# version 3 format
|
||||
values.append(col_type.from_binary(item, proto_version))
|
||||
|
||||
if len(values) < len(cls.subtypes):
|
||||
nones = [None] * (len(cls.subtypes) - len(values))
|
||||
values = values + nones
|
||||
|
||||
return tuple(values)
|
||||
|
||||
@classmethod
|
||||
def serialize_safe(cls, val, protocol_version):
|
||||
proto_version = max(3, protocol_version)
|
||||
buf = io.BytesIO()
|
||||
for item, subtype in zip(val, cls.subtypes):
|
||||
packed_item = subtype.to_binary(item, proto_version)
|
||||
buf.write(int32_pack(len(packed_item)))
|
||||
buf.write(packed_item)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
class UserDefinedType(TupleType):
|
||||
typename = "'org.apache.cassandra.db.marshal.UserType'"
|
||||
|
||||
_cache = {}
|
||||
|
||||
@@ -804,24 +839,7 @@ class UserDefinedType(_ParameterizedType):
|
||||
|
||||
@classmethod
|
||||
def deserialize_safe(cls, byts, protocol_version):
|
||||
proto_version = max(3, protocol_version)
|
||||
p = 0
|
||||
values = []
|
||||
for col_type in cls.subtypes:
|
||||
if p == len(byts):
|
||||
break
|
||||
itemlen = int32_unpack(byts[p:p + cls.FIELD_LENGTH])
|
||||
p += cls.FIELD_LENGTH
|
||||
item = byts[p:p + itemlen]
|
||||
p += itemlen
|
||||
# collections inside UDTs are always encoded with at least the
|
||||
# version 3 format
|
||||
values.append(col_type.from_binary(item, proto_version))
|
||||
|
||||
if len(values) < len(cls.subtypes):
|
||||
nones = [None] * (len(cls.subtypes) - len(values))
|
||||
values = values + nones
|
||||
|
||||
values = TupleType.deserialize_safe(byts, protocol_version)
|
||||
if cls.mapped_class:
|
||||
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
|
||||
else:
|
||||
|
||||
@@ -84,6 +84,7 @@ def cql_encode_date(val):
|
||||
def cql_encode_sequence(val):
|
||||
return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v)
|
||||
for v in val)
|
||||
cql_encode_tuple = cql_encode_sequence
|
||||
|
||||
|
||||
def cql_encode_map_collection(val):
|
||||
|
||||
@@ -33,6 +33,7 @@ except ImportError:
|
||||
from cassandra import InvalidRequest
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.cqltypes import Int32Type, EMPTY
|
||||
from cassandra.encoder import cql_encode_tuple
|
||||
from cassandra.query import dict_factory
|
||||
from cassandra.util import OrderedDict
|
||||
|
||||
@@ -398,3 +399,37 @@ class TypeTests(unittest.TestCase):
|
||||
s.execute(prepared, parameters=(dt,))
|
||||
result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b
|
||||
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
|
||||
|
||||
def test_tuple_type(self):
|
||||
if self._cass_version < (2, 1, 0):
|
||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect()
|
||||
s.encoders[tuple] = cql_encode_tuple
|
||||
|
||||
s.execute("""CREATE KEYSPACE test_tuple_type
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
|
||||
s.set_keyspace("test_tuple_type")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b tuple<ascii, int, boolean>)")
|
||||
|
||||
# test non-prepared statement
|
||||
complete = ('foo', 123, True)
|
||||
s.execute("INSERT INTO mytable (a, b) VALUES (0, %s)", parameters=(complete,))
|
||||
result = s.execute("SELECT b FROM mytable WHERE a=0")[0]
|
||||
self.assertEqual(complete, result.b)
|
||||
|
||||
partial = ('bar', 456)
|
||||
partial_result = partial + (None,)
|
||||
s.execute("INSERT INTO mytable (a, b) VALUES (1, %s)", parameters=(partial,))
|
||||
result = s.execute("SELECT b FROM mytable WHERE a=1")[0]
|
||||
self.assertEqual(partial_result, result.b)
|
||||
|
||||
# test prepared statement
|
||||
prepared = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
|
||||
s.execute(prepared, parameters=(2, complete))
|
||||
s.execute(prepared, parameters=(3, partial))
|
||||
|
||||
prepared = s.prepare("SELECT b FROM mytable WHERE a=?")
|
||||
self.assertEqual(complete, s.execute(prepared, (2,))[0].b)
|
||||
self.assertEqual(partial_result, s.execute(prepared, (3,))[0].b)
|
||||
|
||||
Reference in New Issue
Block a user