cqle: refactor container init to common method

PYTHON-306
This commit is contained in:
Adam Holmberg
2016-02-11 11:36:56 -06:00
parent 3d66a3124f
commit f38c9125ed
2 changed files with 35 additions and 69 deletions

View File

@@ -16,6 +16,7 @@ from copy import deepcopy, copy
from datetime import date, datetime from datetime import date, datetime
import logging import logging
import six import six
from uuid import UUID as _UUID
from cassandra import util from cassandra import util
from cassandra.cqltypes import SimpleDateType from cassandra.cqltypes import SimpleDateType
@@ -504,7 +505,6 @@ class UUID(Column):
val = super(UUID, self).validate(value) val = super(UUID, self).validate(value)
if val is None: if val is None:
return return
from uuid import UUID as _UUID
if isinstance(val, _UUID): if isinstance(val, _UUID):
return val return val
if isinstance(val, six.string_types): if isinstance(val, six.string_types):
@@ -523,9 +523,6 @@ class UUID(Column):
return self.validate(value) return self.validate(value)
from uuid import UUID as pyUUID, getnode
class TimeUUID(UUID): class TimeUUID(UUID):
""" """
UUID containing timestamp UUID containing timestamp
@@ -615,24 +612,22 @@ class BaseContainerColumn(Column):
https://cassandra.apache.org/doc/cql3/CQL.html#collections https://cassandra.apache.org/doc/cql3/CQL.html#collections
""" """
def __init__(self, value_type, **kwargs): def __init__(self, types, **kwargs):
""" """
:param value_type: a column class indicating the types of the value :param types: a sequence of sub types in this collection
""" """
inheritance_comparator = issubclass if isinstance(value_type, type) else isinstance instances = []
if not inheritance_comparator(value_type, Column): for t in types:
raise ValidationError('value_type must be a column class') inheritance_comparator = issubclass if isinstance(t, type) else isinstance
if inheritance_comparator(value_type, BaseContainerColumn): if not inheritance_comparator(t, Column):
raise ValidationError('container types cannot be nested') raise ValidationError("%s is not a column class" % (t,))
if value_type.db_type is None: if inheritance_comparator(t, BaseContainerColumn): # should go away with PYTHON-478
raise ValidationError('value_type cannot be an abstract column type') raise ValidationError('container types cannot be nested')
if t.db_type is None:
raise ValidationError("%s is an abstract type" % (t,))
if isinstance(value_type, type): instances.append(t() if isinstance(t, type) else t)
self.value_type = value_type self.types = instances
self.value_col = self.value_type()
else:
self.value_col = value_type
self.value_type = self.value_col.__class__
super(BaseContainerColumn, self).__init__(**kwargs) super(BaseContainerColumn, self).__init__(**kwargs)
@@ -649,7 +644,7 @@ class BaseContainerColumn(Column):
@property @property
def sub_types(self): def sub_types(self):
return [self.value_col] return self.types
class Set(BaseContainerColumn): class Set(BaseContainerColumn):
@@ -666,7 +661,10 @@ class Set(BaseContainerColumn):
""" """
self.strict = strict self.strict = strict
self.db_type = 'set<{0}>'.format(value_type.db_type) self.db_type = 'set<{0}>'.format(value_type.db_type)
super(Set, self).__init__(value_type, default=default, **kwargs)
super(Set, self).__init__((value_type,), default=default, **kwargs)
self.value_col = self.types[0]
def validate(self, value): def validate(self, value):
val = super(Set, self).validate(value) val = super(Set, self).validate(value)
@@ -706,7 +704,10 @@ class List(BaseContainerColumn):
:param value_type: a column class indicating the types of the value :param value_type: a column class indicating the types of the value
""" """
self.db_type = 'list<{0}>'.format(value_type.db_type) self.db_type = 'list<{0}>'.format(value_type.db_type)
return super(List, self).__init__(value_type=value_type, default=default, **kwargs)
super(List, self).__init__((value_type,), default=default, **kwargs)
self.value_col = self.types[0]
def validate(self, value): def validate(self, value):
val = super(List, self).validate(value) val = super(List, self).validate(value)
@@ -743,21 +744,10 @@ class Map(BaseContainerColumn):
self.db_type = 'map<{0}, {1}>'.format(key_type.db_type, value_type.db_type) self.db_type = 'map<{0}, {1}>'.format(key_type.db_type, value_type.db_type)
inheritance_comparator = issubclass if isinstance(key_type, type) else isinstance super(Map, self).__init__((key_type, value_type), default=default, **kwargs)
if not inheritance_comparator(key_type, Column):
raise ValidationError('key_type must be a column class')
if inheritance_comparator(key_type, BaseContainerColumn):
raise ValidationError('container types cannot be nested')
if key_type.db_type is None:
raise ValidationError('key_type cannot be an abstract column type')
if isinstance(key_type, type): self.key_col = self.types[0]
self.key_type = key_type self.value_col = self.types[1]
self.key_col = self.key_type()
else:
self.key_col = key_type
self.key_type = self.key_col.__class__
super(Map, self).__init__(value_type, default=default, **kwargs)
def validate(self, value): def validate(self, value):
val = super(Map, self).validate(value) val = super(Map, self).validate(value)
@@ -780,10 +770,6 @@ class Map(BaseContainerColumn):
return None return None
return dict((self.key_col.to_database(k), self.value_col.to_database(v)) for k, v in value.items()) return dict((self.key_col.to_database(k), self.value_col.to_database(v)) for k, v in value.items())
@property
def sub_types(self):
return [self.key_col, self.value_col]
class UDTValueManager(BaseValueManager): class UDTValueManager(BaseValueManager):
@property @property
@@ -796,7 +782,7 @@ class UDTValueManager(BaseValueManager):
self.previous_value = copy(self.value) self.previous_value = copy(self.value)
class Tuple(Column): class Tuple(BaseContainerColumn):
""" """
Stores a fixed-length set of positional values Stores a fixed-length set of positional values
@@ -812,46 +798,26 @@ class Tuple(Column):
if not args: if not args:
raise ValueError("Tuple must specify at least one inner type") raise ValueError("Tuple must specify at least one inner type")
types = [] super(Tuple, self).__init__(args, **kwargs)
for arg in args:
inheritance_comparator = issubclass if isinstance(arg, type) else isinstance
if not inheritance_comparator(arg, Column):
raise ValidationError("%s is not a column class" % (arg,))
if arg.db_type is None:
raise ValidationError("%s is an abstract type" % (arg,))
types.append(arg() if isinstance(arg, type) else arg)
self.types = types
super(Tuple, self).__init__(**kwargs)
def validate(self, value): def validate(self, value):
val = super(Tuple, self).validate(value) val = super(Tuple, self).validate(value)
if val is None: if val is None:
return return
if len(val) > self.sub_types: if len(val) > len(self.types):
raise ValidationError("Value %r has more fields than tuple definition (%s)" % raise ValidationError("Value %r has more fields than tuple definition (%s)" %
(val, ', '.join(t for t in self.sub_types))) (val, ', '.join(t for t in self.types)))
return tuple(t.validate(v) for t, v in zip(self.sub_types, val)) return tuple(t.validate(v) for t, v in zip(self.types, val))
def to_python(self, value): def to_python(self, value):
if value is None: if value is None:
return tuple() return tuple()
return tuple(t.to_python(v) for t, v in zip(self.sub_types, value)) return tuple(t.to_python(v) for t, v in zip(self.types, value))
def to_database(self, value): def to_database(self, value):
if value is None: if value is None:
return return
return tuple(t.to_database(v) for t, v in zip(self.sub_types, value)) return tuple(t.to_database(v) for t, v in zip(self.types, value))
@property
def sub_types(self):
return self.types
# TODO:
# test UDTs in tuples, vice versa
# refactor init validation to common base
# cqlsh None in tuple should display as null
class UserDefinedType(Column): class UserDefinedType(Column):

View File

@@ -814,8 +814,8 @@ class ModelMetaClass(type):
raise ModelDefinitionException("column '{0}' conflicts with built-in attribute/method".format(k)) raise ModelDefinitionException("column '{0}' conflicts with built-in attribute/method".format(k))
# counter column primary keys are not allowed # counter column primary keys are not allowed
if (v.primary_key or v.partition_key) and isinstance(v, (columns.Counter, columns.BaseContainerColumn)): if (v.primary_key or v.partition_key) and isinstance(v, columns.Counter):
raise ModelDefinitionException('counter columns and container columns cannot be used as primary keys') raise ModelDefinitionException('counter columns cannot be used as primary keys')
# this will mark the first primary key column as a partition # this will mark the first primary key column as a partition
# key, if one hasn't been set already # key, if one hasn't been set already