Merge pull request #305 from datastax/PYTHON-311
PYTHON-311 - Correct CQL encoding for collections in cqlengine
This commit is contained in:
@@ -73,25 +73,6 @@ class BaseValueManager(object):
|
||||
return property(_get, _set)
|
||||
|
||||
|
||||
class ValueQuoter(object):
|
||||
"""
|
||||
contains a single value, which will quote itself for CQL insertion statements
|
||||
"""
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __str__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
return self.value == other.value
|
||||
return False
|
||||
|
||||
|
||||
class Column(object):
|
||||
|
||||
# the cassandra type this column maps to
|
||||
@@ -715,24 +696,12 @@ class BaseContainerColumn(Column):
|
||||
return [self.value_col]
|
||||
|
||||
|
||||
class BaseContainerQuoter(ValueQuoter):
|
||||
|
||||
def __nonzero__(self):
|
||||
return bool(self.value)
|
||||
|
||||
|
||||
class Set(BaseContainerColumn):
|
||||
"""
|
||||
Stores a set of unordered, unique values
|
||||
|
||||
http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_set_t.html
|
||||
"""
|
||||
class Quoter(BaseContainerQuoter):
|
||||
|
||||
def __str__(self):
|
||||
cq = cql_quote
|
||||
return '{' + ', '.join([cq(v) for v in self.value]) + '}'
|
||||
|
||||
def __init__(self, value_type, strict=True, default=set, **kwargs):
|
||||
"""
|
||||
:param value_type: a column class indicating the types of the value
|
||||
@@ -767,10 +736,7 @@ class Set(BaseContainerColumn):
|
||||
def to_database(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, self.Quoter):
|
||||
return value
|
||||
return self.Quoter({self.value_col.to_database(v) for v in value})
|
||||
return {self.value_col.to_database(v) for v in value}
|
||||
|
||||
|
||||
class List(BaseContainerColumn):
|
||||
@@ -779,15 +745,6 @@ class List(BaseContainerColumn):
|
||||
|
||||
http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_list_t.html
|
||||
"""
|
||||
class Quoter(BaseContainerQuoter):
|
||||
|
||||
def __str__(self):
|
||||
cq = cql_quote
|
||||
return '[' + ', '.join([cq(v) for v in self.value]) + ']'
|
||||
|
||||
def __nonzero__(self):
|
||||
return bool(self.value)
|
||||
|
||||
def __init__(self, value_type, default=list, **kwargs):
|
||||
"""
|
||||
:param value_type: a column class indicating the types of the value
|
||||
@@ -813,9 +770,7 @@ class List(BaseContainerColumn):
|
||||
def to_database(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, self.Quoter):
|
||||
return value
|
||||
return self.Quoter([self.value_col.to_database(v) for v in value])
|
||||
return [self.value_col.to_database(v) for v in value]
|
||||
|
||||
|
||||
class Map(BaseContainerColumn):
|
||||
@@ -824,21 +779,6 @@ class Map(BaseContainerColumn):
|
||||
|
||||
http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_map_t.html
|
||||
"""
|
||||
class Quoter(BaseContainerQuoter):
|
||||
|
||||
def __str__(self):
|
||||
cq = cql_quote
|
||||
return '{' + ', '.join([cq(k) + ':' + cq(v) for k, v in self.value.items()]) + '}'
|
||||
|
||||
def get(self, key):
|
||||
return self.value.get(key)
|
||||
|
||||
def keys(self):
|
||||
return self.value.keys()
|
||||
|
||||
def items(self):
|
||||
return self.value.items()
|
||||
|
||||
def __init__(self, key_type, value_type, default=dict, **kwargs):
|
||||
"""
|
||||
:param key_type: a column class indicating the types of the key
|
||||
@@ -880,9 +820,7 @@ class Map(BaseContainerColumn):
|
||||
def to_database(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, self.Quoter):
|
||||
return value
|
||||
return self.Quoter({self.key_col.to_database(k): self.value_col.to_database(v) for k, v in value.items()})
|
||||
return {self.key_col.to_database(k): self.value_col.to_database(v) for k, v in value.items()}
|
||||
|
||||
@property
|
||||
def sub_columns(self):
|
||||
|
@@ -312,6 +312,8 @@ def _sync_type(ks_name, type_model, omit_subtypes=None):
|
||||
if field.db_field_name not in defined_fields:
|
||||
execute("ALTER TYPE {} ADD {}".format(type_name_qualified, field.get_column_def()))
|
||||
|
||||
type_model.register_for_keyspace(ks_name)
|
||||
|
||||
if len(defined_fields) == len(model_fields):
|
||||
log.info("Type %s did not require synchronization", type_name_qualified)
|
||||
return
|
||||
@@ -320,8 +322,6 @@ def _sync_type(ks_name, type_model, omit_subtypes=None):
|
||||
if db_fields_not_in_model:
|
||||
log.info("Type %s has fields not referenced by model: %s", type_name_qualified, db_fields_not_in_model)
|
||||
|
||||
type_model.register_for_keyspace(ks_name)
|
||||
|
||||
|
||||
def get_create_type(type_model, keyspace):
|
||||
type_meta = metadata.UserType(keyspace,
|
||||
|
@@ -658,7 +658,7 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
|
||||
|
||||
class InsertStatement(AssignmentStatement):
|
||||
""" an cql insert select statement """
|
||||
""" an cql insert statement """
|
||||
|
||||
def __init__(self,
|
||||
table,
|
||||
|
@@ -161,8 +161,8 @@ class TestSetColumn(BaseCassEngTestCase):
|
||||
column = columns.Set(JsonTestColumn)
|
||||
val = {1, 2, 3}
|
||||
db_val = column.to_database(val)
|
||||
assert db_val.value == {json.dumps(v) for v in val}
|
||||
py_val = column.to_python(db_val.value)
|
||||
assert db_val == {json.dumps(v) for v in val}
|
||||
py_val = column.to_python(db_val)
|
||||
assert py_val == val
|
||||
|
||||
def test_default_empty_container_saving(self):
|
||||
@@ -277,8 +277,8 @@ class TestListColumn(BaseCassEngTestCase):
|
||||
column = columns.List(JsonTestColumn)
|
||||
val = [1, 2, 3]
|
||||
db_val = column.to_database(val)
|
||||
assert db_val.value == [json.dumps(v) for v in val]
|
||||
py_val = column.to_python(db_val.value)
|
||||
assert db_val == [json.dumps(v) for v in val]
|
||||
py_val = column.to_python(db_val)
|
||||
assert py_val == val
|
||||
|
||||
def test_default_empty_container_saving(self):
|
||||
@@ -495,8 +495,8 @@ class TestMapColumn(BaseCassEngTestCase):
|
||||
column = columns.Map(JsonTestColumn, JsonTestColumn)
|
||||
val = {1: 2, 3: 4, 5: 6}
|
||||
db_val = column.to_database(val)
|
||||
assert db_val.value == {json.dumps(k):json.dumps(v) for k,v in val.items()}
|
||||
py_val = column.to_python(db_val.value)
|
||||
assert db_val == {json.dumps(k):json.dumps(v) for k,v in val.items()}
|
||||
py_val = column.to_python(db_val)
|
||||
assert py_val == val
|
||||
|
||||
def test_default_empty_container_saving(self):
|
||||
|
@@ -22,7 +22,6 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase
|
||||
from cassandra.cqlengine.management import sync_table
|
||||
from cassandra.cqlengine.management import drop_table
|
||||
from cassandra.cqlengine.models import Model
|
||||
from cassandra.cqlengine.columns import ValueQuoter
|
||||
from cassandra.cqlengine import columns
|
||||
import unittest
|
||||
|
||||
@@ -211,11 +210,3 @@ class TestDecimalIO(BaseColumnIOTest):
|
||||
def comparator_converter(self, val):
|
||||
return Decimal(val)
|
||||
|
||||
|
||||
class TestQuoter(unittest.TestCase):
|
||||
|
||||
def test_equals(self):
|
||||
assert ValueQuoter(False) == ValueQuoter(False)
|
||||
assert ValueQuoter(1) == ValueQuoter(1)
|
||||
assert ValueQuoter("foo") == ValueQuoter("foo")
|
||||
assert ValueQuoter(1.55) == ValueQuoter(1.55)
|
||||
|
@@ -60,25 +60,25 @@ class UpdateStatementTests(TestCase):
|
||||
|
||||
def test_update_set_add(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(SetUpdateClause('a', Set.Quoter({1}), operation='add'))
|
||||
us.add_assignment_clause(SetUpdateClause('a', {1}, operation='add'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
|
||||
|
||||
def test_update_empty_set_add_does_not_assign(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(SetUpdateClause('a', Set.Quoter(set()), operation='add'))
|
||||
us.add_assignment_clause(SetUpdateClause('a', set(), operation='add'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
|
||||
|
||||
def test_update_empty_set_removal_does_not_assign(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(SetUpdateClause('a', Set.Quoter(set()), operation='remove'))
|
||||
us.add_assignment_clause(SetUpdateClause('a', set(), operation='remove'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" - %(0)s')
|
||||
|
||||
def test_update_list_prepend_with_empty_list(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(ListUpdateClause('a', List.Quoter([]), operation='prepend'))
|
||||
us.add_assignment_clause(ListUpdateClause('a', [], operation='prepend'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s + "a"')
|
||||
|
||||
def test_update_list_append_with_empty_list(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(ListUpdateClause('a', List.Quoter([]), operation='append'))
|
||||
us.add_assignment_clause(ListUpdateClause('a', [], operation='append'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
|
||||
|
Reference in New Issue
Block a user