Add support for tuple type
This commit is contained in:
		@@ -1086,7 +1086,7 @@ class Session(object):
 | 
			
		||||
        self._metrics = cluster.metrics
 | 
			
		||||
        self._protocol_version = self.cluster.protocol_version
 | 
			
		||||
 | 
			
		||||
        self._encoders = cql_encoders.copy()
 | 
			
		||||
        self.encoders = cql_encoders.copy()
 | 
			
		||||
 | 
			
		||||
        # create connection pools in parallel
 | 
			
		||||
        futures = []
 | 
			
		||||
@@ -1210,7 +1210,7 @@ class Session(object):
 | 
			
		||||
        if isinstance(query, SimpleStatement):
 | 
			
		||||
            query_string = query.query_string
 | 
			
		||||
            if parameters:
 | 
			
		||||
                query_string = bind_params(query.query_string, parameters, self._encoders)
 | 
			
		||||
                query_string = bind_params(query.query_string, parameters, self.encoders)
 | 
			
		||||
            message = QueryMessage(
 | 
			
		||||
                query_string, cl, query.serial_consistency_level,
 | 
			
		||||
                fetch_size, timestamp=timestamp)
 | 
			
		||||
@@ -1460,7 +1460,7 @@ class Session(object):
 | 
			
		||||
                cql_encode_all_types(getattr(val, field_name))
 | 
			
		||||
            ) for field_name in type_meta.field_names)
 | 
			
		||||
 | 
			
		||||
        self._encoders[klass] = encode
 | 
			
		||||
        self.encoders[klass] = encode
 | 
			
		||||
 | 
			
		||||
    def submit(self, fn, *args, **kwargs):
 | 
			
		||||
        """ Internal """
 | 
			
		||||
 
 | 
			
		||||
@@ -774,10 +774,45 @@ class MapType(_ParameterizedType):
 | 
			
		||||
        return buf.getvalue()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UserDefinedType(_ParameterizedType):
 | 
			
		||||
    typename = "'org.apache.cassandra.db.marshal.UserType'"
 | 
			
		||||
class TupleType(_ParameterizedType):
 | 
			
		||||
    typename = 'tuple'
 | 
			
		||||
    num_subtypes = 'UNKNOWN'
 | 
			
		||||
 | 
			
		||||
    FIELD_LENGTH = 4
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def deserialize_safe(cls, byts, protocol_version):
 | 
			
		||||
        proto_version = max(3, protocol_version)
 | 
			
		||||
        p = 0
 | 
			
		||||
        values = []
 | 
			
		||||
        for col_type in cls.subtypes:
 | 
			
		||||
            if p == len(byts):
 | 
			
		||||
                break
 | 
			
		||||
            itemlen = int32_unpack(byts[p:p + 4])
 | 
			
		||||
            p += 4
 | 
			
		||||
            item = byts[p:p + itemlen]
 | 
			
		||||
            p += itemlen
 | 
			
		||||
            # collections inside UDTs are always encoded with at least the
 | 
			
		||||
            # version 3 format
 | 
			
		||||
            values.append(col_type.from_binary(item, proto_version))
 | 
			
		||||
 | 
			
		||||
        if len(values) < len(cls.subtypes):
 | 
			
		||||
            nones = [None] * (len(cls.subtypes) - len(values))
 | 
			
		||||
            values = values + nones
 | 
			
		||||
 | 
			
		||||
        return tuple(values)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def serialize_safe(cls, val, protocol_version):
 | 
			
		||||
        proto_version = max(3, protocol_version)
 | 
			
		||||
        buf = io.BytesIO()
 | 
			
		||||
        for item, subtype in zip(val, cls.subtypes):
 | 
			
		||||
            packed_item = subtype.to_binary(item, proto_version)
 | 
			
		||||
            buf.write(int32_pack(len(packed_item)))
 | 
			
		||||
            buf.write(packed_item)
 | 
			
		||||
        return buf.getvalue()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UserDefinedType(TupleType):
 | 
			
		||||
    typename = "'org.apache.cassandra.db.marshal.UserType'"
 | 
			
		||||
 | 
			
		||||
    _cache = {}
 | 
			
		||||
 | 
			
		||||
@@ -804,24 +839,7 @@ class UserDefinedType(_ParameterizedType):
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def deserialize_safe(cls, byts, protocol_version):
 | 
			
		||||
        proto_version = max(3, protocol_version)
 | 
			
		||||
        p = 0
 | 
			
		||||
        values = []
 | 
			
		||||
        for col_type in cls.subtypes:
 | 
			
		||||
            if p == len(byts):
 | 
			
		||||
                break
 | 
			
		||||
            itemlen = int32_unpack(byts[p:p + cls.FIELD_LENGTH])
 | 
			
		||||
            p += cls.FIELD_LENGTH
 | 
			
		||||
            item = byts[p:p + itemlen]
 | 
			
		||||
            p += itemlen
 | 
			
		||||
            # collections inside UDTs are always encoded with at least the
 | 
			
		||||
            # version 3 format
 | 
			
		||||
            values.append(col_type.from_binary(item, proto_version))
 | 
			
		||||
 | 
			
		||||
        if len(values) < len(cls.subtypes):
 | 
			
		||||
            nones = [None] * (len(cls.subtypes) - len(values))
 | 
			
		||||
            values = values + nones
 | 
			
		||||
 | 
			
		||||
        values = TupleType.deserialize_safe(byts, protocol_version)
 | 
			
		||||
        if cls.mapped_class:
 | 
			
		||||
            return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
 | 
			
		||||
        else:
 | 
			
		||||
 
 | 
			
		||||
@@ -84,6 +84,7 @@ def cql_encode_date(val):
 | 
			
		||||
def cql_encode_sequence(val):
 | 
			
		||||
    return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v)
 | 
			
		||||
                                 for v in val)
 | 
			
		||||
cql_encode_tuple = cql_encode_sequence
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cql_encode_map_collection(val):
 | 
			
		||||
 
 | 
			
		||||
@@ -33,6 +33,7 @@ except ImportError:
 | 
			
		||||
from cassandra import InvalidRequest
 | 
			
		||||
from cassandra.cluster import Cluster
 | 
			
		||||
from cassandra.cqltypes import Int32Type, EMPTY
 | 
			
		||||
from cassandra.encoder import cql_encode_tuple
 | 
			
		||||
from cassandra.query import dict_factory
 | 
			
		||||
from cassandra.util import OrderedDict
 | 
			
		||||
 | 
			
		||||
@@ -398,3 +399,37 @@ class TypeTests(unittest.TestCase):
 | 
			
		||||
        s.execute(prepared, parameters=(dt,))
 | 
			
		||||
        result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b
 | 
			
		||||
        self.assertEqual(dt.utctimetuple(), result.utctimetuple())
 | 
			
		||||
 | 
			
		||||
    def test_tuple_type(self):
 | 
			
		||||
        if self._cass_version < (2, 1, 0):
 | 
			
		||||
            raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
 | 
			
		||||
 | 
			
		||||
        c = Cluster(protocol_version=PROTOCOL_VERSION)
 | 
			
		||||
        s = c.connect()
 | 
			
		||||
        s.encoders[tuple] = cql_encode_tuple
 | 
			
		||||
 | 
			
		||||
        s.execute("""CREATE KEYSPACE test_tuple_type
 | 
			
		||||
            WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
 | 
			
		||||
        s.set_keyspace("test_tuple_type")
 | 
			
		||||
        s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b tuple<ascii, int, boolean>)")
 | 
			
		||||
 | 
			
		||||
        # test non-prepared statement
 | 
			
		||||
        complete = ('foo', 123, True)
 | 
			
		||||
        s.execute("INSERT INTO mytable (a, b) VALUES (0, %s)", parameters=(complete,))
 | 
			
		||||
        result = s.execute("SELECT b FROM mytable WHERE a=0")[0]
 | 
			
		||||
        self.assertEqual(complete, result.b)
 | 
			
		||||
 | 
			
		||||
        partial = ('bar', 456)
 | 
			
		||||
        partial_result = partial + (None,)
 | 
			
		||||
        s.execute("INSERT INTO mytable (a, b) VALUES (1, %s)", parameters=(partial,))
 | 
			
		||||
        result = s.execute("SELECT b FROM mytable WHERE a=1")[0]
 | 
			
		||||
        self.assertEqual(partial_result, result.b)
 | 
			
		||||
 | 
			
		||||
        # test prepared statement
 | 
			
		||||
        prepared = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
 | 
			
		||||
        s.execute(prepared, parameters=(2, complete))
 | 
			
		||||
        s.execute(prepared, parameters=(3, partial))
 | 
			
		||||
 | 
			
		||||
        prepared = s.prepare("SELECT b FROM mytable WHERE a=?")
 | 
			
		||||
        self.assertEqual(complete, s.execute(prepared, (2,))[0].b)
 | 
			
		||||
        self.assertEqual(partial_result, s.execute(prepared, (3,))[0].b)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user