cqle: refactor container init to common method

PYTHON-306
This commit is contained in:
Adam Holmberg
2016-02-11 11:36:56 -06:00
parent 3d66a3124f
commit f38c9125ed
2 changed files with 35 additions and 69 deletions

View File

@@ -16,6 +16,7 @@ from copy import deepcopy, copy
from datetime import date, datetime
import logging
import six
from uuid import UUID as _UUID
from cassandra import util
from cassandra.cqltypes import SimpleDateType
@@ -504,7 +505,6 @@ class UUID(Column):
val = super(UUID, self).validate(value)
if val is None:
return
from uuid import UUID as _UUID
if isinstance(val, _UUID):
return val
if isinstance(val, six.string_types):
@@ -523,9 +523,6 @@ class UUID(Column):
return self.validate(value)
from uuid import UUID as pyUUID, getnode
class TimeUUID(UUID):
"""
UUID containing timestamp
@@ -615,24 +612,22 @@ class BaseContainerColumn(Column):
https://cassandra.apache.org/doc/cql3/CQL.html#collections
"""
def __init__(self, value_type, **kwargs):
def __init__(self, types, **kwargs):
"""
:param value_type: a column class indicating the types of the value
:param types: a sequence of sub types in this collection
"""
inheritance_comparator = issubclass if isinstance(value_type, type) else isinstance
if not inheritance_comparator(value_type, Column):
raise ValidationError('value_type must be a column class')
if inheritance_comparator(value_type, BaseContainerColumn):
instances = []
for t in types:
inheritance_comparator = issubclass if isinstance(t, type) else isinstance
if not inheritance_comparator(t, Column):
raise ValidationError("%s is not a column class" % (t,))
if inheritance_comparator(t, BaseContainerColumn): # should go away with PYTHON-478
raise ValidationError('container types cannot be nested')
if value_type.db_type is None:
raise ValidationError('value_type cannot be an abstract column type')
if t.db_type is None:
raise ValidationError("%s is an abstract type" % (t,))
if isinstance(value_type, type):
self.value_type = value_type
self.value_col = self.value_type()
else:
self.value_col = value_type
self.value_type = self.value_col.__class__
instances.append(t() if isinstance(t, type) else t)
self.types = instances
super(BaseContainerColumn, self).__init__(**kwargs)
@@ -649,7 +644,7 @@ class BaseContainerColumn(Column):
@property
def sub_types(self):
return [self.value_col]
return self.types
class Set(BaseContainerColumn):
@@ -666,7 +661,10 @@ class Set(BaseContainerColumn):
"""
self.strict = strict
self.db_type = 'set<{0}>'.format(value_type.db_type)
super(Set, self).__init__(value_type, default=default, **kwargs)
super(Set, self).__init__((value_type,), default=default, **kwargs)
self.value_col = self.types[0]
def validate(self, value):
val = super(Set, self).validate(value)
@@ -706,7 +704,10 @@ class List(BaseContainerColumn):
:param value_type: a column class indicating the types of the value
"""
self.db_type = 'list<{0}>'.format(value_type.db_type)
return super(List, self).__init__(value_type=value_type, default=default, **kwargs)
super(List, self).__init__((value_type,), default=default, **kwargs)
self.value_col = self.types[0]
def validate(self, value):
val = super(List, self).validate(value)
@@ -743,21 +744,10 @@ class Map(BaseContainerColumn):
self.db_type = 'map<{0}, {1}>'.format(key_type.db_type, value_type.db_type)
inheritance_comparator = issubclass if isinstance(key_type, type) else isinstance
if not inheritance_comparator(key_type, Column):
raise ValidationError('key_type must be a column class')
if inheritance_comparator(key_type, BaseContainerColumn):
raise ValidationError('container types cannot be nested')
if key_type.db_type is None:
raise ValidationError('key_type cannot be an abstract column type')
super(Map, self).__init__((key_type, value_type), default=default, **kwargs)
if isinstance(key_type, type):
self.key_type = key_type
self.key_col = self.key_type()
else:
self.key_col = key_type
self.key_type = self.key_col.__class__
super(Map, self).__init__(value_type, default=default, **kwargs)
self.key_col = self.types[0]
self.value_col = self.types[1]
def validate(self, value):
val = super(Map, self).validate(value)
@@ -780,10 +770,6 @@ class Map(BaseContainerColumn):
return None
return dict((self.key_col.to_database(k), self.value_col.to_database(v)) for k, v in value.items())
@property
def sub_types(self):
return [self.key_col, self.value_col]
class UDTValueManager(BaseValueManager):
@property
@@ -796,7 +782,7 @@ class UDTValueManager(BaseValueManager):
self.previous_value = copy(self.value)
class Tuple(Column):
class Tuple(BaseContainerColumn):
"""
Stores a fixed-length set of positional values
@@ -812,46 +798,26 @@ class Tuple(Column):
if not args:
raise ValueError("Tuple must specify at least one inner type")
types = []
for arg in args:
inheritance_comparator = issubclass if isinstance(arg, type) else isinstance
if not inheritance_comparator(arg, Column):
raise ValidationError("%s is not a column class" % (arg,))
if arg.db_type is None:
raise ValidationError("%s is an abstract type" % (arg,))
types.append(arg() if isinstance(arg, type) else arg)
self.types = types
super(Tuple, self).__init__(**kwargs)
super(Tuple, self).__init__(args, **kwargs)
def validate(self, value):
val = super(Tuple, self).validate(value)
if val is None:
return
if len(val) > self.sub_types:
if len(val) > len(self.types):
raise ValidationError("Value %r has more fields than tuple definition (%s)" %
(val, ', '.join(t for t in self.sub_types)))
return tuple(t.validate(v) for t, v in zip(self.sub_types, val))
(val, ', '.join(t for t in self.types)))
return tuple(t.validate(v) for t, v in zip(self.types, val))
def to_python(self, value):
if value is None:
return tuple()
return tuple(t.to_python(v) for t, v in zip(self.sub_types, value))
return tuple(t.to_python(v) for t, v in zip(self.types, value))
def to_database(self, value):
if value is None:
return
return tuple(t.to_database(v) for t, v in zip(self.sub_types, value))
@property
def sub_types(self):
return self.types
# TODO:
# test UDTs in tuples, vice versa
# refactor init validation to common base
# cqlsh None in tuple should display as null
return tuple(t.to_database(v) for t, v in zip(self.types, value))
class UserDefinedType(Column):

View File

@@ -814,8 +814,8 @@ class ModelMetaClass(type):
raise ModelDefinitionException("column '{0}' conflicts with built-in attribute/method".format(k))
# counter column primary keys are not allowed
if (v.primary_key or v.partition_key) and isinstance(v, (columns.Counter, columns.BaseContainerColumn)):
raise ModelDefinitionException('counter columns and container columns cannot be used as primary keys')
if (v.primary_key or v.partition_key) and isinstance(v, columns.Counter):
raise ModelDefinitionException('counter columns cannot be used as primary keys')
# this will mark the first primary key column as a partition
# key, if one hasn't been set already