Merge pull request #305 from datastax/PYTHON-311

PYTHON-311 - Correct CQL encoding for collections in cqlengine
This commit is contained in:
Adam Holmberg
2015-05-19 13:38:23 -05:00
6 changed files with 17 additions and 88 deletions

View File

@@ -73,25 +73,6 @@ class BaseValueManager(object):
return property(_get, _set) return property(_get, _set)
class ValueQuoter(object):
"""
contains a single value, which will quote itself for CQL insertion statements
"""
def __init__(self, value):
self.value = value
def __str__(self):
raise NotImplementedError
def __repr__(self):
return self.__str__()
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.value == other.value
return False
class Column(object): class Column(object):
# the cassandra type this column maps to # the cassandra type this column maps to
@@ -715,24 +696,12 @@ class BaseContainerColumn(Column):
return [self.value_col] return [self.value_col]
class BaseContainerQuoter(ValueQuoter):
def __nonzero__(self):
return bool(self.value)
class Set(BaseContainerColumn): class Set(BaseContainerColumn):
""" """
Stores a set of unordered, unique values Stores a set of unordered, unique values
http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_set_t.html http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_set_t.html
""" """
class Quoter(BaseContainerQuoter):
def __str__(self):
cq = cql_quote
return '{' + ', '.join([cq(v) for v in self.value]) + '}'
def __init__(self, value_type, strict=True, default=set, **kwargs): def __init__(self, value_type, strict=True, default=set, **kwargs):
""" """
:param value_type: a column class indicating the types of the value :param value_type: a column class indicating the types of the value
@@ -767,10 +736,7 @@ class Set(BaseContainerColumn):
def to_database(self, value): def to_database(self, value):
if value is None: if value is None:
return None return None
return {self.value_col.to_database(v) for v in value}
if isinstance(value, self.Quoter):
return value
return self.Quoter({self.value_col.to_database(v) for v in value})
class List(BaseContainerColumn): class List(BaseContainerColumn):
@@ -779,15 +745,6 @@ class List(BaseContainerColumn):
http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_list_t.html http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_list_t.html
""" """
class Quoter(BaseContainerQuoter):
def __str__(self):
cq = cql_quote
return '[' + ', '.join([cq(v) for v in self.value]) + ']'
def __nonzero__(self):
return bool(self.value)
def __init__(self, value_type, default=list, **kwargs): def __init__(self, value_type, default=list, **kwargs):
""" """
:param value_type: a column class indicating the types of the value :param value_type: a column class indicating the types of the value
@@ -813,9 +770,7 @@ class List(BaseContainerColumn):
def to_database(self, value): def to_database(self, value):
if value is None: if value is None:
return None return None
if isinstance(value, self.Quoter): return [self.value_col.to_database(v) for v in value]
return value
return self.Quoter([self.value_col.to_database(v) for v in value])
class Map(BaseContainerColumn): class Map(BaseContainerColumn):
@@ -824,21 +779,6 @@ class Map(BaseContainerColumn):
http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_map_t.html http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_map_t.html
""" """
class Quoter(BaseContainerQuoter):
def __str__(self):
cq = cql_quote
return '{' + ', '.join([cq(k) + ':' + cq(v) for k, v in self.value.items()]) + '}'
def get(self, key):
return self.value.get(key)
def keys(self):
return self.value.keys()
def items(self):
return self.value.items()
def __init__(self, key_type, value_type, default=dict, **kwargs): def __init__(self, key_type, value_type, default=dict, **kwargs):
""" """
:param key_type: a column class indicating the types of the key :param key_type: a column class indicating the types of the key
@@ -880,9 +820,7 @@ class Map(BaseContainerColumn):
def to_database(self, value): def to_database(self, value):
if value is None: if value is None:
return None return None
if isinstance(value, self.Quoter): return {self.key_col.to_database(k): self.value_col.to_database(v) for k, v in value.items()}
return value
return self.Quoter({self.key_col.to_database(k): self.value_col.to_database(v) for k, v in value.items()})
@property @property
def sub_columns(self): def sub_columns(self):

View File

@@ -312,6 +312,8 @@ def _sync_type(ks_name, type_model, omit_subtypes=None):
if field.db_field_name not in defined_fields: if field.db_field_name not in defined_fields:
execute("ALTER TYPE {} ADD {}".format(type_name_qualified, field.get_column_def())) execute("ALTER TYPE {} ADD {}".format(type_name_qualified, field.get_column_def()))
type_model.register_for_keyspace(ks_name)
if len(defined_fields) == len(model_fields): if len(defined_fields) == len(model_fields):
log.info("Type %s did not require synchronization", type_name_qualified) log.info("Type %s did not require synchronization", type_name_qualified)
return return
@@ -320,8 +322,6 @@ def _sync_type(ks_name, type_model, omit_subtypes=None):
if db_fields_not_in_model: if db_fields_not_in_model:
log.info("Type %s has fields not referenced by model: %s", type_name_qualified, db_fields_not_in_model) log.info("Type %s has fields not referenced by model: %s", type_name_qualified, db_fields_not_in_model)
type_model.register_for_keyspace(ks_name)
def get_create_type(type_model, keyspace): def get_create_type(type_model, keyspace):
type_meta = metadata.UserType(keyspace, type_meta = metadata.UserType(keyspace,

View File

@@ -658,7 +658,7 @@ class AssignmentStatement(BaseCQLStatement):
class InsertStatement(AssignmentStatement): class InsertStatement(AssignmentStatement):
""" an cql insert select statement """ """ an cql insert statement """
def __init__(self, def __init__(self,
table, table,

View File

@@ -161,8 +161,8 @@ class TestSetColumn(BaseCassEngTestCase):
column = columns.Set(JsonTestColumn) column = columns.Set(JsonTestColumn)
val = {1, 2, 3} val = {1, 2, 3}
db_val = column.to_database(val) db_val = column.to_database(val)
assert db_val.value == {json.dumps(v) for v in val} assert db_val == {json.dumps(v) for v in val}
py_val = column.to_python(db_val.value) py_val = column.to_python(db_val)
assert py_val == val assert py_val == val
def test_default_empty_container_saving(self): def test_default_empty_container_saving(self):
@@ -277,8 +277,8 @@ class TestListColumn(BaseCassEngTestCase):
column = columns.List(JsonTestColumn) column = columns.List(JsonTestColumn)
val = [1, 2, 3] val = [1, 2, 3]
db_val = column.to_database(val) db_val = column.to_database(val)
assert db_val.value == [json.dumps(v) for v in val] assert db_val == [json.dumps(v) for v in val]
py_val = column.to_python(db_val.value) py_val = column.to_python(db_val)
assert py_val == val assert py_val == val
def test_default_empty_container_saving(self): def test_default_empty_container_saving(self):
@@ -495,8 +495,8 @@ class TestMapColumn(BaseCassEngTestCase):
column = columns.Map(JsonTestColumn, JsonTestColumn) column = columns.Map(JsonTestColumn, JsonTestColumn)
val = {1: 2, 3: 4, 5: 6} val = {1: 2, 3: 4, 5: 6}
db_val = column.to_database(val) db_val = column.to_database(val)
assert db_val.value == {json.dumps(k):json.dumps(v) for k,v in val.items()} assert db_val == {json.dumps(k):json.dumps(v) for k,v in val.items()}
py_val = column.to_python(db_val.value) py_val = column.to_python(db_val)
assert py_val == val assert py_val == val
def test_default_empty_container_saving(self): def test_default_empty_container_saving(self):

View File

@@ -22,7 +22,6 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase
from cassandra.cqlengine.management import sync_table from cassandra.cqlengine.management import sync_table
from cassandra.cqlengine.management import drop_table from cassandra.cqlengine.management import drop_table
from cassandra.cqlengine.models import Model from cassandra.cqlengine.models import Model
from cassandra.cqlengine.columns import ValueQuoter
from cassandra.cqlengine import columns from cassandra.cqlengine import columns
import unittest import unittest
@@ -211,11 +210,3 @@ class TestDecimalIO(BaseColumnIOTest):
def comparator_converter(self, val): def comparator_converter(self, val):
return Decimal(val) return Decimal(val)
class TestQuoter(unittest.TestCase):
def test_equals(self):
assert ValueQuoter(False) == ValueQuoter(False)
assert ValueQuoter(1) == ValueQuoter(1)
assert ValueQuoter("foo") == ValueQuoter("foo")
assert ValueQuoter(1.55) == ValueQuoter(1.55)

View File

@@ -60,25 +60,25 @@ class UpdateStatementTests(TestCase):
def test_update_set_add(self): def test_update_set_add(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', Set.Quoter({1}), operation='add')) us.add_assignment_clause(SetUpdateClause('a', {1}, operation='add'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
def test_update_empty_set_add_does_not_assign(self): def test_update_empty_set_add_does_not_assign(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', Set.Quoter(set()), operation='add')) us.add_assignment_clause(SetUpdateClause('a', set(), operation='add'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
def test_update_empty_set_removal_does_not_assign(self): def test_update_empty_set_removal_does_not_assign(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', Set.Quoter(set()), operation='remove')) us.add_assignment_clause(SetUpdateClause('a', set(), operation='remove'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" - %(0)s') self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" - %(0)s')
def test_update_list_prepend_with_empty_list(self): def test_update_list_prepend_with_empty_list(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(ListUpdateClause('a', List.Quoter([]), operation='prepend')) us.add_assignment_clause(ListUpdateClause('a', [], operation='prepend'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s + "a"') self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s + "a"')
def test_update_list_append_with_empty_list(self): def test_update_list_append_with_empty_list(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(ListUpdateClause('a', List.Quoter([]), operation='append')) us.add_assignment_clause(ListUpdateClause('a', [], operation='append'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')