diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py index b3856c7e..2243be6b 100644 --- a/cassandra/cqlengine/columns.py +++ b/cassandra/cqlengine/columns.py @@ -16,6 +16,7 @@ from copy import deepcopy, copy from datetime import date, datetime import logging import six +from uuid import UUID as _UUID from cassandra import util from cassandra.cqltypes import SimpleDateType @@ -504,7 +505,6 @@ class UUID(Column): val = super(UUID, self).validate(value) if val is None: return - from uuid import UUID as _UUID if isinstance(val, _UUID): return val if isinstance(val, six.string_types): @@ -523,9 +523,6 @@ class UUID(Column): return self.validate(value) -from uuid import UUID as pyUUID, getnode - - class TimeUUID(UUID): """ UUID containing timestamp @@ -615,24 +612,22 @@ class BaseContainerColumn(Column): https://cassandra.apache.org/doc/cql3/CQL.html#collections """ - def __init__(self, value_type, **kwargs): + def __init__(self, types, **kwargs): """ - :param value_type: a column class indicating the types of the value + :param types: a sequence of sub types in this collection """ - inheritance_comparator = issubclass if isinstance(value_type, type) else isinstance - if not inheritance_comparator(value_type, Column): - raise ValidationError('value_type must be a column class') - if inheritance_comparator(value_type, BaseContainerColumn): - raise ValidationError('container types cannot be nested') - if value_type.db_type is None: - raise ValidationError('value_type cannot be an abstract column type') + instances = [] + for t in types: + inheritance_comparator = issubclass if isinstance(t, type) else isinstance + if not inheritance_comparator(t, Column): + raise ValidationError("%s is not a column class" % (t,)) + if inheritance_comparator(t, BaseContainerColumn): # should go away with PYTHON-478 + raise ValidationError('container types cannot be nested') + if t.db_type is None: + raise ValidationError("%s is an abstract type" % (t,)) - if isinstance(value_type, type): - self.value_type = value_type - self.value_col = self.value_type() - else: - self.value_col = value_type - self.value_type = self.value_col.__class__ + instances.append(t() if isinstance(t, type) else t) + self.types = instances super(BaseContainerColumn, self).__init__(**kwargs) @@ -649,7 +644,7 @@ class BaseContainerColumn(Column): @property def sub_types(self): - return [self.value_col] + return self.types class Set(BaseContainerColumn): @@ -666,7 +661,10 @@ class Set(BaseContainerColumn): """ self.strict = strict self.db_type = 'set<{0}>'.format(value_type.db_type) - super(Set, self).__init__(value_type, default=default, **kwargs) + + super(Set, self).__init__((value_type,), default=default, **kwargs) + + self.value_col = self.types[0] def validate(self, value): val = super(Set, self).validate(value) @@ -706,7 +704,10 @@ class List(BaseContainerColumn): :param value_type: a column class indicating the types of the value """ self.db_type = 'list<{0}>'.format(value_type.db_type) - return super(List, self).__init__(value_type=value_type, default=default, **kwargs) + + super(List, self).__init__((value_type,), default=default, **kwargs) + + self.value_col = self.types[0] def validate(self, value): val = super(List, self).validate(value) @@ -743,21 +744,10 @@ class Map(BaseContainerColumn): self.db_type = 'map<{0}, {1}>'.format(key_type.db_type, value_type.db_type) - inheritance_comparator = issubclass if isinstance(key_type, type) else isinstance - if not inheritance_comparator(key_type, Column): - raise ValidationError('key_type must be a column class') - if inheritance_comparator(key_type, BaseContainerColumn): - raise ValidationError('container types cannot be nested') - if key_type.db_type is None: - raise ValidationError('key_type cannot be an abstract column type') + super(Map, self).__init__((key_type, value_type), default=default, **kwargs) - if isinstance(key_type, type): - self.key_type = key_type - self.key_col = self.key_type() - else: - self.key_col = key_type - self.key_type = self.key_col.__class__ - super(Map, self).__init__(value_type, default=default, **kwargs) + self.key_col = self.types[0] + self.value_col = self.types[1] def validate(self, value): val = super(Map, self).validate(value) @@ -780,10 +770,6 @@ class Map(BaseContainerColumn): return None return dict((self.key_col.to_database(k), self.value_col.to_database(v)) for k, v in value.items()) - @property - def sub_types(self): - return [self.key_col, self.value_col] - class UDTValueManager(BaseValueManager): @property @@ -796,7 +782,7 @@ class UDTValueManager(BaseValueManager): self.previous_value = copy(self.value) -class Tuple(Column): +class Tuple(BaseContainerColumn): """ Stores a fixed-length set of positional values @@ -812,46 +798,26 @@ class Tuple(Column): if not args: raise ValueError("Tuple must specify at least one inner type") - types = [] - for arg in args: - inheritance_comparator = issubclass if isinstance(arg, type) else isinstance - if not inheritance_comparator(arg, Column): - raise ValidationError("%s is not a column class" % (arg,)) - if arg.db_type is None: - raise ValidationError("%s is an abstract type" % (arg,)) - - types.append(arg() if isinstance(arg, type) else arg) - self.types = types - - super(Tuple, self).__init__(**kwargs) + super(Tuple, self).__init__(args, **kwargs) def validate(self, value): val = super(Tuple, self).validate(value) if val is None: return - if len(val) > self.sub_types: + if len(val) > len(self.types): raise ValidationError("Value %r has more fields than tuple definition (%s)" % - (val, ', '.join(t for t in self.sub_types))) - return tuple(t.validate(v) for t, v in zip(self.sub_types, val)) + (val, ', '.join(t for t in self.types))) + return tuple(t.validate(v) for t, v in zip(self.types, val)) def to_python(self, value): if value is None: return tuple() - return tuple(t.to_python(v) for t, v in zip(self.sub_types, value)) + return tuple(t.to_python(v) for t, v in zip(self.types, value)) def to_database(self, value): if value is None: return - return tuple(t.to_database(v) for t, v in zip(self.sub_types, value)) - - @property - def sub_types(self): - return self.types - -# TODO: -# test UDTs in tuples, vice versa -# refactor init validation to common base -# cqlsh None in tuple should display as null + return tuple(t.to_database(v) for t, v in zip(self.types, value)) class UserDefinedType(Column): diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py index 3eeef449..e80ed7da 100644 --- a/cassandra/cqlengine/models.py +++ b/cassandra/cqlengine/models.py @@ -814,8 +814,8 @@ class ModelMetaClass(type): raise ModelDefinitionException("column '{0}' conflicts with built-in attribute/method".format(k)) # counter column primary keys are not allowed - if (v.primary_key or v.partition_key) and isinstance(v, (columns.Counter, columns.BaseContainerColumn)): - raise ModelDefinitionException('counter columns and container columns cannot be used as primary keys') + if (v.primary_key or v.partition_key) and isinstance(v, columns.Counter): + raise ModelDefinitionException('counter columns cannot be used as primary keys') # this will mark the first primary key column as a partition # key, if one hasn't been set already