Make key sequences require type wrapping, not collections
This commit is contained in:
@@ -793,8 +793,9 @@ cql_encoders = {
|
||||
datetime.datetime: cql_encode_datetime,
|
||||
datetime.date: cql_encode_date,
|
||||
dict: cql_encode_map_collection,
|
||||
list: cql_encode_sequence,
|
||||
tuple: cql_encode_sequence,
|
||||
set: cql_encode_sequence,
|
||||
list: cql_encode_list_collection,
|
||||
tuple: cql_encode_list_collection,
|
||||
set: cql_encode_set_collection,
|
||||
frozenset: cql_encode_set_collection,
|
||||
types.GeneratorType: cql_encode_sequence
|
||||
}
|
||||
|
||||
@@ -2,9 +2,7 @@ import struct
|
||||
|
||||
from cassandra import ConsistencyLevel
|
||||
from cassandra.decoder import (cql_encoders, cql_encode_object,
|
||||
cql_encode_map_collection,
|
||||
cql_encode_set_collection,
|
||||
cql_encode_list_collection)
|
||||
cql_encode_sequence)
|
||||
|
||||
class Query(object):
|
||||
|
||||
@@ -138,19 +136,13 @@ class BoundStatement(Query):
|
||||
return self._routing_key
|
||||
|
||||
|
||||
class ColumnCollection(object):
|
||||
class KeySequence(object):
|
||||
|
||||
def __init__(self, sequence):
|
||||
self.sequence = sequence
|
||||
|
||||
def __str__(self):
|
||||
s = self.sequence
|
||||
if isinstance(s, dict):
|
||||
return cql_encode_map_collection(s)
|
||||
elif isinstance(s, (set, frozenset)):
|
||||
return cql_encode_set_collection(s)
|
||||
else:
|
||||
return cql_encode_list_collection(s)
|
||||
return cql_encode_sequence(self.sequence)
|
||||
|
||||
|
||||
def bind_params(query, params):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
|
||||
from cassandra.query import bind_params, ColumnCollection
|
||||
from cassandra.query import bind_params, KeySequence
|
||||
|
||||
class ParamBindingTest(unittest.TestCase):
|
||||
|
||||
@@ -13,7 +13,7 @@ class ParamBindingTest(unittest.TestCase):
|
||||
self.assertEquals(result, "1 'a' 2.0")
|
||||
|
||||
def test_sequence_param(self):
|
||||
result = bind_params("%s", ((1, "a", 2.0),))
|
||||
result = bind_params("%s", (KeySequence((1, "a", 2.0)),))
|
||||
self.assertEquals(result, "( 1 , 'a' , 2.0 )")
|
||||
|
||||
def test_generator_param(self):
|
||||
@@ -25,15 +25,15 @@ class ParamBindingTest(unittest.TestCase):
|
||||
self.assertEquals(result, "NULL")
|
||||
|
||||
def test_list_collection(self):
|
||||
result = bind_params("%s", (ColumnCollection(['a', 'b', 'c']),))
|
||||
result = bind_params("%s", (['a', 'b', 'c'],))
|
||||
self.assertEquals(result, "[ 'a' , 'b' , 'c' ]")
|
||||
|
||||
def test_set_collection(self):
|
||||
result = bind_params("%s", (ColumnCollection({'a', 'b', 'c'}),))
|
||||
result = bind_params("%s", ({'a', 'b', 'c'},))
|
||||
self.assertEquals(result, "{ 'a' , 'c' , 'b' }")
|
||||
|
||||
def test_map_collection(self):
|
||||
result = bind_params("%s", (ColumnCollection({'a': 'a', 'b': 'b'}),))
|
||||
result = bind_params("%s", ({'a': 'a', 'b': 'b'},))
|
||||
self.assertEquals(result, "{ 'a' : 'a' , 'b' : 'b' }")
|
||||
|
||||
def test_quote_escaping(self):
|
||||
|
||||
Reference in New Issue
Block a user