diff --git a/cqlengine/columns.py b/cqlengine/columns.py index 15a02748..fe7ad9e3 100644 --- a/cqlengine/columns.py +++ b/cqlengine/columns.py @@ -1,7 +1,9 @@ #column field types +from copy import copy from datetime import datetime import re from uuid import uuid1, uuid4 +from cql.query import cql_quote from cqlengine.exceptions import ValidationError @@ -10,7 +12,7 @@ class BaseValueManager(object): def __init__(self, instance, column, value): self.instance = instance self.column = column - self.initial_value = value + self.initial_value = copy(value) self.value = value @property @@ -121,7 +123,7 @@ class Column(object): """ Returns a column definition for CQL table definition """ - return '{} {}'.format(self.db_field_name, self.db_type) + return '"{}" {}'.format(self.db_field_name, self.db_type) def set_column_name(self, name): """ @@ -156,6 +158,7 @@ class Text(Column): def validate(self, value): value = super(Text, self).validate(value) + if value is None: return if not isinstance(value, (basestring, bytearray)) and value is not None: raise ValidationError('{} is not a string'.format(type(value))) if self.max_length: @@ -171,6 +174,7 @@ class Integer(Column): def validate(self, value): val = super(Integer, self).validate(value) + if val is None: return try: return long(val) except (TypeError, ValueError): @@ -212,6 +216,7 @@ class UUID(Column): def validate(self, value): val = super(UUID, self).validate(value) + if val is None: return from uuid import UUID as _UUID if isinstance(val, _UUID): return val if not self.re_uuid.match(val): @@ -246,6 +251,8 @@ class Float(Column): super(Float, self).__init__(**kwargs) def validate(self, value): + value = super(Float, self).validate(value) + if value is None: return try: return float(value) except (TypeError, ValueError): @@ -266,3 +273,161 @@ class Counter(Column): super(Counter, self).__init__(**kwargs) raise NotImplementedError +class ContainerValueManager(BaseValueManager): + pass + +class ContainerQuoter(object): + """ + contains a single value, which will quote itself for CQL insertion statements + """ + def __init__(self, value): + self.value = value + + def __str__(self): + raise NotImplementedError + +class BaseContainerColumn(Column): + """ + Base Container type + """ + + def __init__(self, value_type, **kwargs): + """ + :param value_type: a column class indicating the types of the value + """ + if not issubclass(value_type, Column): + raise ValidationError('value_type must be a column class') + if issubclass(value_type, BaseContainerColumn): + raise ValidationError('container types cannot be nested') + if value_type.db_type is None: + raise ValidationError('value_type cannot be an abstract column type') + + self.value_type = value_type + self.value_col = self.value_type() + super(BaseContainerColumn, self).__init__(**kwargs) + + def get_column_def(self): + """ + Returns a column definition for CQL table definition + """ + db_type = self.db_type.format(self.value_type.db_type) + return '{} {}'.format(self.db_field_name, db_type) + +class Set(BaseContainerColumn): + """ + Stores a set of unordered, unique values + + http://www.datastax.com/docs/1.2/cql_cli/using/collections + """ + db_type = 'set<{}>' + + class Quoter(ContainerQuoter): + + def __str__(self): + cq = cql_quote + return '{' + ', '.join([cq(v) for v in self.value]) + '}' + + def __init__(self, value_type, strict=True, **kwargs): + """ + :param value_type: a column class indicating the types of the value + :param strict: sets whether non set values will be coerced to set + type on validation, or raise a validation error, defaults to True + """ + self.strict = strict + super(Set, self).__init__(value_type, **kwargs) + + def validate(self, value): + val = super(Set, self).validate(value) + if val is None: return + types = (set,) if self.strict else (set, list, tuple) + if not isinstance(val, types): + if self.strict: + raise ValidationError('{} is not a set object'.format(val)) + else: + raise ValidationError('{} cannot be coerced to a set object'.format(val)) + + return {self.value_col.validate(v) for v in val} + + def to_database(self, value): + return self.Quoter({self.value_col.to_database(v) for v in value}) + +class List(BaseContainerColumn): + """ + Stores a list of ordered values + + http://www.datastax.com/docs/1.2/cql_cli/using/collections_list + """ + db_type = 'list<{}>' + + class Quoter(ContainerQuoter): + + def __str__(self): + cq = cql_quote + return '[' + ', '.join([cq(v) for v in self.value]) + ']' + + def validate(self, value): + val = super(List, self).validate(value) + if val is None: return + if not isinstance(val, (set, list, tuple)): + raise ValidationError('{} is not a list object'.format(val)) + return [self.value_col.validate(v) for v in val] + + def to_database(self, value): + return self.Quoter([self.value_col.to_database(v) for v in value]) + +class Map(BaseContainerColumn): + """ + Stores a key -> value map (dictionary) + + http://www.datastax.com/docs/1.2/cql_cli/using/collections_map + """ + + db_type = 'map<{}, {}>' + + class Quoter(ContainerQuoter): + + def __str__(self): + cq = cql_quote + return '{' + ', '.join([cq(k) + ':' + cq(v) for k,v in self.value.items()]) + '}' + + def __init__(self, key_type, value_type, **kwargs): + """ + :param key_type: a column class indicating the types of the key + :param value_type: a column class indicating the types of the value + """ + if not issubclass(value_type, Column): + raise ValidationError('key_type must be a column class') + if issubclass(value_type, BaseContainerColumn): + raise ValidationError('container types cannot be nested') + if key_type.db_type is None: + raise ValidationError('key_type cannot be an abstract column type') + + self.key_type = key_type + self.key_col = self.key_type() + super(Map, self).__init__(value_type, **kwargs) + + def get_column_def(self): + """ + Returns a column definition for CQL table definition + """ + db_type = self.db_type.format( + self.key_type.db_type, + self.value_type.db_type + ) + return '{} {}'.format(self.db_field_name, db_type) + + def validate(self, value): + val = super(Map, self).validate(value) + if val is None: return + if not isinstance(val, dict): + raise ValidationError('{} is not a dict object'.format(val)) + return {self.key_col.validate(k):self.value_col.validate(v) for k,v in val.items()} + + def to_python(self, value): + if value is not None: + return {self.key_col.to_python(k):self.value_col.to_python(v) for k,v in value.items()} + + def to_database(self, value): + return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()}) + + diff --git a/cqlengine/management.py b/cqlengine/management.py index 20ec5741..ba8b8145 100644 --- a/cqlengine/management.py +++ b/cqlengine/management.py @@ -62,7 +62,7 @@ def create_table(model, create_missing_keyspace=True): pkeys = [] qtypes = [] def add_column(col): - s = '"{}" {}'.format(col.db_field_name, col.db_type) + s = col.get_column_def() if col.primary_key: pkeys.append('"{}"'.format(col.db_field_name)) qtypes.append(s) for name, col in model._columns.items(): diff --git a/cqlengine/tests/columns/test_container_columns.py b/cqlengine/tests/columns/test_container_columns.py new file mode 100644 index 00000000..064a7cc3 --- /dev/null +++ b/cqlengine/tests/columns/test_container_columns.py @@ -0,0 +1,116 @@ +from datetime import datetime, timedelta +from uuid import uuid4 + +from cqlengine import Model +from cqlengine import columns +from cqlengine.management import create_table, delete_table +from cqlengine.tests.base import BaseCassEngTestCase + +class TestSetModel(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + int_set = columns.Set(columns.Integer, required=False) + text_set = columns.Set(columns.Text, required=False) + +class TestSetColumn(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestSetColumn, cls).setUpClass() + delete_table(TestSetModel) + create_table(TestSetModel) + + @classmethod + def tearDownClass(cls): + super(TestSetColumn, cls).tearDownClass() + delete_table(TestSetModel) + + def test_io_success(self): + """ Tests that a basic usage works as expected """ + m1 = TestSetModel.create(int_set={1,2}, text_set={'kai', 'andreas'}) + m2 = TestSetModel.get(partition=m1.partition) + + assert isinstance(m2.int_set, set) + assert isinstance(m2.text_set, set) + + assert 1 in m2.int_set + assert 2 in m2.int_set + + assert 'kai' in m2.text_set + assert 'andreas' in m2.text_set + + +class TestListModel(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + int_list = columns.List(columns.Integer, required=False) + text_list = columns.List(columns.Text, required=False) + +class TestListColumn(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestListColumn, cls).setUpClass() + delete_table(TestListModel) + create_table(TestListModel) + + @classmethod + def tearDownClass(cls): + super(TestListColumn, cls).tearDownClass() + delete_table(TestListModel) + + def test_io_success(self): + """ Tests that a basic usage works as expected """ + m1 = TestListModel.create(int_list=[1,2], text_list=['kai', 'andreas']) + m2 = TestListModel.get(partition=m1.partition) + + assert isinstance(m2.int_list, tuple) + assert isinstance(m2.text_list, tuple) + + assert len(m2.int_list) == 2 + assert len(m2.text_list) == 2 + + assert m2.int_list[0] == 1 + assert m2.int_list[1] == 2 + + assert m2.text_list[0] == 'kai' + assert m2.text_list[1] == 'andreas' + + +class TestMapModel(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + int_map = columns.Map(columns.Integer, columns.UUID, required=False) + text_map = columns.Map(columns.Text, columns.DateTime, required=False) + +class TestMapColumn(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestMapColumn, cls).setUpClass() + delete_table(TestMapModel) + create_table(TestMapModel) + + @classmethod + def tearDownClass(cls): + super(TestMapColumn, cls).tearDownClass() + delete_table(TestMapModel) + + def test_io_success(self): + """ Tests that a basic usage works as expected """ + k1 = uuid4() + k2 = uuid4() + now = datetime.now() + then = now + timedelta(days=1) + m1 = TestMapModel.create(int_map={1:k1,2:k2}, text_map={'now':now, 'then':then}) + m2 = TestMapModel.get(partition=m1.partition) + + assert isinstance(m2.int_map, dict) + assert isinstance(m2.text_map, dict) + + assert 1 in m2.int_map + assert 2 in m2.int_map + assert m2.int_map[1] == k1 + assert m2.int_map[2] == k2 + + assert 'now' in m2.text_map + assert 'then' in m2.text_map + assert (now - m2.text_map['now']).total_seconds() < 0.001 + assert (then - m2.text_map['then']).total_seconds() < 0.001