Make key sequences require type wrapping, not collections

This commit is contained in:
Tyler Hobbs
2013-07-01 11:22:49 -05:00
parent f22912b3d7
commit eee785d6bb
3 changed files with 12 additions and 19 deletions

View File

@@ -793,8 +793,9 @@ cql_encoders = {
datetime.datetime: cql_encode_datetime,
datetime.date: cql_encode_date,
dict: cql_encode_map_collection,
list: cql_encode_sequence,
tuple: cql_encode_sequence,
set: cql_encode_sequence,
list: cql_encode_list_collection,
tuple: cql_encode_list_collection,
set: cql_encode_set_collection,
frozenset: cql_encode_set_collection,
types.GeneratorType: cql_encode_sequence
}

View File

@@ -2,9 +2,7 @@ import struct
from cassandra import ConsistencyLevel
from cassandra.decoder import (cql_encoders, cql_encode_object,
cql_encode_map_collection,
cql_encode_set_collection,
cql_encode_list_collection)
cql_encode_sequence)
class Query(object):
@@ -138,19 +136,13 @@ class BoundStatement(Query):
return self._routing_key
class ColumnCollection(object):
class KeySequence(object):
def __init__(self, sequence):
self.sequence = sequence
def __str__(self):
s = self.sequence
if isinstance(s, dict):
return cql_encode_map_collection(s)
elif isinstance(s, (set, frozenset)):
return cql_encode_set_collection(s)
else:
return cql_encode_list_collection(s)
return cql_encode_sequence(self.sequence)
def bind_params(query, params):

View File

@@ -1,6 +1,6 @@
import unittest
from cassandra.query import bind_params, ColumnCollection
from cassandra.query import bind_params, KeySequence
class ParamBindingTest(unittest.TestCase):
@@ -13,7 +13,7 @@ class ParamBindingTest(unittest.TestCase):
self.assertEquals(result, "1 'a' 2.0")
def test_sequence_param(self):
result = bind_params("%s", ((1, "a", 2.0),))
result = bind_params("%s", (KeySequence((1, "a", 2.0)),))
self.assertEquals(result, "( 1 , 'a' , 2.0 )")
def test_generator_param(self):
@@ -25,15 +25,15 @@ class ParamBindingTest(unittest.TestCase):
self.assertEquals(result, "NULL")
def test_list_collection(self):
result = bind_params("%s", (ColumnCollection(['a', 'b', 'c']),))
result = bind_params("%s", (['a', 'b', 'c'],))
self.assertEquals(result, "[ 'a' , 'b' , 'c' ]")
def test_set_collection(self):
result = bind_params("%s", (ColumnCollection({'a', 'b', 'c'}),))
result = bind_params("%s", ({'a', 'b', 'c'},))
self.assertEquals(result, "{ 'a' , 'c' , 'b' }")
def test_map_collection(self):
result = bind_params("%s", (ColumnCollection({'a': 'a', 'b': 'b'}),))
result = bind_params("%s", ({'a': 'a', 'b': 'b'},))
self.assertEquals(result, "{ 'a' : 'a' , 'b' : 'b' }")
def test_quote_escaping(self):