Add support for tuple type
This commit is contained in:
@@ -1086,7 +1086,7 @@ class Session(object):
|
|||||||
self._metrics = cluster.metrics
|
self._metrics = cluster.metrics
|
||||||
self._protocol_version = self.cluster.protocol_version
|
self._protocol_version = self.cluster.protocol_version
|
||||||
|
|
||||||
self._encoders = cql_encoders.copy()
|
self.encoders = cql_encoders.copy()
|
||||||
|
|
||||||
# create connection pools in parallel
|
# create connection pools in parallel
|
||||||
futures = []
|
futures = []
|
||||||
@@ -1210,7 +1210,7 @@ class Session(object):
|
|||||||
if isinstance(query, SimpleStatement):
|
if isinstance(query, SimpleStatement):
|
||||||
query_string = query.query_string
|
query_string = query.query_string
|
||||||
if parameters:
|
if parameters:
|
||||||
query_string = bind_params(query.query_string, parameters, self._encoders)
|
query_string = bind_params(query.query_string, parameters, self.encoders)
|
||||||
message = QueryMessage(
|
message = QueryMessage(
|
||||||
query_string, cl, query.serial_consistency_level,
|
query_string, cl, query.serial_consistency_level,
|
||||||
fetch_size, timestamp=timestamp)
|
fetch_size, timestamp=timestamp)
|
||||||
@@ -1460,7 +1460,7 @@ class Session(object):
|
|||||||
cql_encode_all_types(getattr(val, field_name))
|
cql_encode_all_types(getattr(val, field_name))
|
||||||
) for field_name in type_meta.field_names)
|
) for field_name in type_meta.field_names)
|
||||||
|
|
||||||
self._encoders[klass] = encode
|
self.encoders[klass] = encode
|
||||||
|
|
||||||
def submit(self, fn, *args, **kwargs):
|
def submit(self, fn, *args, **kwargs):
|
||||||
""" Internal """
|
""" Internal """
|
||||||
|
|||||||
@@ -774,10 +774,45 @@ class MapType(_ParameterizedType):
|
|||||||
return buf.getvalue()
|
return buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedType(_ParameterizedType):
|
class TupleType(_ParameterizedType):
|
||||||
typename = "'org.apache.cassandra.db.marshal.UserType'"
|
typename = 'tuple'
|
||||||
|
num_subtypes = 'UNKNOWN'
|
||||||
|
|
||||||
FIELD_LENGTH = 4
|
@classmethod
|
||||||
|
def deserialize_safe(cls, byts, protocol_version):
|
||||||
|
proto_version = max(3, protocol_version)
|
||||||
|
p = 0
|
||||||
|
values = []
|
||||||
|
for col_type in cls.subtypes:
|
||||||
|
if p == len(byts):
|
||||||
|
break
|
||||||
|
itemlen = int32_unpack(byts[p:p + 4])
|
||||||
|
p += 4
|
||||||
|
item = byts[p:p + itemlen]
|
||||||
|
p += itemlen
|
||||||
|
# collections inside UDTs are always encoded with at least the
|
||||||
|
# version 3 format
|
||||||
|
values.append(col_type.from_binary(item, proto_version))
|
||||||
|
|
||||||
|
if len(values) < len(cls.subtypes):
|
||||||
|
nones = [None] * (len(cls.subtypes) - len(values))
|
||||||
|
values = values + nones
|
||||||
|
|
||||||
|
return tuple(values)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def serialize_safe(cls, val, protocol_version):
|
||||||
|
proto_version = max(3, protocol_version)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
for item, subtype in zip(val, cls.subtypes):
|
||||||
|
packed_item = subtype.to_binary(item, proto_version)
|
||||||
|
buf.write(int32_pack(len(packed_item)))
|
||||||
|
buf.write(packed_item)
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedType(TupleType):
|
||||||
|
typename = "'org.apache.cassandra.db.marshal.UserType'"
|
||||||
|
|
||||||
_cache = {}
|
_cache = {}
|
||||||
|
|
||||||
@@ -804,24 +839,7 @@ class UserDefinedType(_ParameterizedType):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize_safe(cls, byts, protocol_version):
|
def deserialize_safe(cls, byts, protocol_version):
|
||||||
proto_version = max(3, protocol_version)
|
values = TupleType.deserialize_safe(byts, protocol_version)
|
||||||
p = 0
|
|
||||||
values = []
|
|
||||||
for col_type in cls.subtypes:
|
|
||||||
if p == len(byts):
|
|
||||||
break
|
|
||||||
itemlen = int32_unpack(byts[p:p + cls.FIELD_LENGTH])
|
|
||||||
p += cls.FIELD_LENGTH
|
|
||||||
item = byts[p:p + itemlen]
|
|
||||||
p += itemlen
|
|
||||||
# collections inside UDTs are always encoded with at least the
|
|
||||||
# version 3 format
|
|
||||||
values.append(col_type.from_binary(item, proto_version))
|
|
||||||
|
|
||||||
if len(values) < len(cls.subtypes):
|
|
||||||
nones = [None] * (len(cls.subtypes) - len(values))
|
|
||||||
values = values + nones
|
|
||||||
|
|
||||||
if cls.mapped_class:
|
if cls.mapped_class:
|
||||||
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
|
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ def cql_encode_date(val):
|
|||||||
def cql_encode_sequence(val):
|
def cql_encode_sequence(val):
|
||||||
return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v)
|
return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v)
|
||||||
for v in val)
|
for v in val)
|
||||||
|
cql_encode_tuple = cql_encode_sequence
|
||||||
|
|
||||||
|
|
||||||
def cql_encode_map_collection(val):
|
def cql_encode_map_collection(val):
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ except ImportError:
|
|||||||
from cassandra import InvalidRequest
|
from cassandra import InvalidRequest
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.cqltypes import Int32Type, EMPTY
|
from cassandra.cqltypes import Int32Type, EMPTY
|
||||||
|
from cassandra.encoder import cql_encode_tuple
|
||||||
from cassandra.query import dict_factory
|
from cassandra.query import dict_factory
|
||||||
from cassandra.util import OrderedDict
|
from cassandra.util import OrderedDict
|
||||||
|
|
||||||
@@ -398,3 +399,37 @@ class TypeTests(unittest.TestCase):
|
|||||||
s.execute(prepared, parameters=(dt,))
|
s.execute(prepared, parameters=(dt,))
|
||||||
result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b
|
result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b
|
||||||
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
|
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
|
||||||
|
|
||||||
|
def test_tuple_type(self):
|
||||||
|
if self._cass_version < (2, 1, 0):
|
||||||
|
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||||
|
|
||||||
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
|
s = c.connect()
|
||||||
|
s.encoders[tuple] = cql_encode_tuple
|
||||||
|
|
||||||
|
s.execute("""CREATE KEYSPACE test_tuple_type
|
||||||
|
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
|
||||||
|
s.set_keyspace("test_tuple_type")
|
||||||
|
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b tuple<ascii, int, boolean>)")
|
||||||
|
|
||||||
|
# test non-prepared statement
|
||||||
|
complete = ('foo', 123, True)
|
||||||
|
s.execute("INSERT INTO mytable (a, b) VALUES (0, %s)", parameters=(complete,))
|
||||||
|
result = s.execute("SELECT b FROM mytable WHERE a=0")[0]
|
||||||
|
self.assertEqual(complete, result.b)
|
||||||
|
|
||||||
|
partial = ('bar', 456)
|
||||||
|
partial_result = partial + (None,)
|
||||||
|
s.execute("INSERT INTO mytable (a, b) VALUES (1, %s)", parameters=(partial,))
|
||||||
|
result = s.execute("SELECT b FROM mytable WHERE a=1")[0]
|
||||||
|
self.assertEqual(partial_result, result.b)
|
||||||
|
|
||||||
|
# test prepared statement
|
||||||
|
prepared = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
|
||||||
|
s.execute(prepared, parameters=(2, complete))
|
||||||
|
s.execute(prepared, parameters=(3, partial))
|
||||||
|
|
||||||
|
prepared = s.prepare("SELECT b FROM mytable WHERE a=?")
|
||||||
|
self.assertEqual(complete, s.execute(prepared, (2,))[0].b)
|
||||||
|
self.assertEqual(partial_result, s.execute(prepared, (3,))[0].b)
|
||||||
|
|||||||
Reference in New Issue
Block a user