from collections import OrderedDict
import re
import warnings

import six

from cqlengine import columns
from cqlengine.exceptions import ModelException, CQLEngineException, ValidationError
from cqlengine.query import ModelQuerySet, DMLQuery, AbstractQueryableColumn, NOT_SET
from cqlengine.query import DoesNotExist as _DoesNotExist
from cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned


class ModelDefinitionException(ModelException): pass


class PolyMorphicModelException(ModelException): pass


DEFAULT_KEYSPACE = None


class UndefinedKeyspaceWarning(Warning):
    pass


class hybrid_classmethod(object):
    """
    Allows a method to behave as both a class method and
    normal instance method depending on how it's called
    """

    def __init__(self, clsmethod, instmethod):
        self.clsmethod = clsmethod
        self.instmethod = instmethod

    def __get__(self, instance, owner):
        if instance is None:
            return self.clsmethod.__get__(owner, owner)
        else:
            return self.instmethod.__get__(instance, owner)

    def __call__(self, *args, **kwargs):
        """
        Just a hint to IDEs that it's ok to call this
        """
        raise NotImplementedError
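

# Illustrative sketch (names assumed, not defined in this module): hybrid_classmethod
# is what makes BaseModel.batch below usable from both the class and an instance.
#
#   ExampleModel.batch(b)    # class access dispatches to _class_batch(cls, b)
#   instance.batch(b)        # instance access dispatches to _inst_batch(self, b)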


class QuerySetDescriptor(object):
    """
    Returns a fresh queryset for the model it's declared on,
    every time it's accessed
    """

    def __get__(self, obj, model):
        """ :rtype: ModelQuerySet """
        if model.__abstract__:
            raise CQLEngineException('cannot execute queries against abstract models')
        queryset = model.__queryset__(model)

        # if this is a concrete polymorphic model, and the polymorphic
        # key is an indexed column, add a filter clause to only return
        # logical rows of the proper type
        if model._is_polymorphic and not model._is_polymorphic_base:
            name, column = model._polymorphic_column_name, model._polymorphic_column
            if column.partition_key or column.index:
                # look for existing poly types
                return queryset.filter(**{name: model.__polymorphic_key__})

        return queryset

    def __call__(self, *args, **kwargs):
        """
        Just a hint to IDEs that it's ok to call this

        :rtype: ModelQuerySet
        """
        raise NotImplementedError
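

# Illustrative usage sketch (model name assumed): because `objects` is a
# QuerySetDescriptor, every access builds a fresh ModelQuerySet, and concrete
# polymorphic models are automatically filtered by their polymorphic key.
#
#   ExampleModel.objects.all()
#   ExampleModel.objects.filter(id=some_id)
#   ExampleModel.objects.get(id=some_id)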


class TransactionDescriptor(object):
    """
    returns a transaction (IF clause) setter bound to the instance,
    or one bound to a fresh queryset for the class
    """
    def __get__(self, instance, model):
        if instance:
            def transaction_setter(*prepared_transaction, **unprepared_transactions):
                if len(prepared_transaction) > 0:
                    transactions = prepared_transaction[0]
                else:
                    transactions = instance.objects.iff(**unprepared_transactions)._transaction
                instance._transaction = transactions
                return instance

            return transaction_setter
        qs = model.__queryset__(model)

        def transaction_setter(**unprepared_transactions):
            transactions = model.objects.iff(**unprepared_transactions)._transaction
            qs._transaction = transactions
            return qs
        return transaction_setter

    def __call__(self, *args, **kwargs):
        raise NotImplementedError


class TTLDescriptor(object):
    """
    returns a ttl setter for the instance or queryset
    """
    def __get__(self, instance, model):
        if instance:
            #instance = copy.deepcopy(instance)
            # instance method
            def ttl_setter(ts):
                instance._ttl = ts
                return instance
            return ttl_setter

        qs = model.__queryset__(model)

        def ttl_setter(ts):
            qs._ttl = ts
            return qs

        return ttl_setter

    def __call__(self, *args, **kwargs):
        raise NotImplementedError


class TimestampDescriptor(object):
    """
    returns a query set descriptor with a timestamp specified
    """
    def __get__(self, instance, model):
        if instance:
            # instance method
            def timestamp_setter(ts):
                instance._timestamp = ts
                return instance
            return timestamp_setter

        return model.objects.timestamp

    def __call__(self, *args, **kwargs):
        raise NotImplementedError


class IfNotExistsDescriptor(object):
    """
    returns a query set descriptor with an if_not_exists flag specified
    """
    def __get__(self, instance, model):
        if instance:
            # instance method
            def ifnotexists_setter(ife):
                instance._if_not_exists = ife
                return instance
            return ifnotexists_setter

        return model.objects.if_not_exists

    def __call__(self, *args, **kwargs):
        raise NotImplementedError


class ConsistencyDescriptor(object):
    """
    returns a consistency setter: bound to the instance when accessed on an
    instance, otherwise bound to a fresh queryset for the class
    """
    def __get__(self, instance, model):
        if instance:
            #instance = copy.deepcopy(instance)
            def consistency_setter(consistency):
                instance.__consistency__ = consistency
                return instance
            return consistency_setter

        qs = model.__queryset__(model)

        def consistency_setter(consistency):
            qs._consistency = consistency
            return qs

        return consistency_setter

    def __call__(self, *args, **kwargs):
        raise NotImplementedError
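

# Illustrative sketch of the chainable per-query options provided by the
# descriptors above. The model, column names, values and consistency level are
# assumptions for illustration and are not defined in this module.
#
#   ExampleModel.ttl(60).create(id=some_id, count=0)          # INSERT ... USING TTL 60
#   ExampleModel.if_not_exists().create(id=some_id, count=0)  # INSERT ... IF NOT EXISTS
#   instance.timestamp(some_timestamp).save()                 # ... USING TIMESTAMP
#   instance.consistency(some_consistency_level).save()       # per-query consistency
#   instance.iff(count=5).update(count=6)                     # UPDATE ... IF count = 5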


class ColumnQueryEvaluator(AbstractQueryableColumn):
    """
    Wraps a column and allows it to be used in comparator
    expressions, returning query operators

    ie:
    Model.column == 5
    """

    def __init__(self, column):
        self.column = column

    def __unicode__(self):
        return self.column.db_field_name

    def _get_column(self):
        """ :rtype: columns.Column """
        return self.column


class ColumnDescriptor(object):
    """
    Handles the reading and writing of column values to and from
    a model instance's value manager, as well as creating
    comparator queries
    """

    def __init__(self, column):
        """
        :param column: the column this descriptor wraps
        :type column: columns.Column
        """
        self.column = column
        self.query_evaluator = ColumnQueryEvaluator(self.column)

    def __get__(self, instance, owner):
        """
        Returns either the value or column, depending
        on if an instance is provided or not

        :param instance: the model instance
        :type instance: Model
        """
        try:
            return instance._values[self.column.column_name].getval()
        except AttributeError:
            return self.query_evaluator

    def __set__(self, instance, value):
        """
        Sets the value on an instance; raises an exception when
        set on the class itself
        TODO: use None instance to create update statements
        """
        if instance:
            return instance._values[self.column.column_name].setval(value)
        else:
            raise AttributeError('cannot reassign column values')

    def __delete__(self, instance):
        """
        Sets the column value to None, if possible
        """
        if instance:
            if self.column.can_delete:
                instance._values[self.column.column_name].delval()
            else:
                raise AttributeError('cannot delete {} columns'.format(self.column.column_name))
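

# Illustrative sketch (model/column names assumed): accessed through an instance,
# a ColumnDescriptor reads and writes the column's value manager; accessed through
# the class, it falls back to the ColumnQueryEvaluator so the attribute can build
# query operators.
#
#   instance.name = 'blake'                                      # writes the value manager
#   ExampleModel.objects.filter(ExampleModel.name == 'blake')    # class access -> query operator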


class BaseModel(object):
    """
    The base model class, don't inherit from this, inherit from Model, defined below
    """

    class DoesNotExist(_DoesNotExist): pass

    class MultipleObjectsReturned(_MultipleObjectsReturned): pass

    objects = QuerySetDescriptor()
    ttl = TTLDescriptor()
    consistency = ConsistencyDescriptor()
    iff = TransactionDescriptor()

    # custom timestamps, see USING TIMESTAMP X
    timestamp = TimestampDescriptor()

    if_not_exists = IfNotExistsDescriptor()

    # _len is lazily created by __len__

    # table names will be generated automatically from the model class name
    # however, you can also define them manually here
    __table_name__ = None

    # the keyspace for this model
    __keyspace__ = None

    # polymorphism options
    __polymorphic_key__ = None

    # compaction options
    __compaction__ = None
    __compaction_tombstone_compaction_interval__ = None
    __compaction_tombstone_threshold__ = None

    # compaction - size tiered options
    __compaction_bucket_high__ = None
    __compaction_bucket_low__ = None
    __compaction_max_threshold__ = None
    __compaction_min_threshold__ = None
    __compaction_min_sstable_size__ = None

    # compaction - leveled options
    __compaction_sstable_size_in_mb__ = None

    # end compaction
    # the queryset class used for this class
    __queryset__ = ModelQuerySet
    __dmlquery__ = DMLQuery

    __default_ttl__ = None  # default ttl value to use
    __consistency__ = None  # can be set per query

    # Additional table properties
    __bloom_filter_fp_chance__ = None
    __caching__ = None
    __comment__ = None
    __dclocal_read_repair_chance__ = None
    __default_time_to_live__ = None
    __gc_grace_seconds__ = None
    __index_interval__ = None
    __memtable_flush_period_in_ms__ = None
    __populate_io_cache_on_flush__ = None
    __read_repair_chance__ = None
    __replicate_on_write__ = None

    _timestamp = None  # optional timestamp to include with the operation (USING TIMESTAMP)

    _if_not_exists = False  # optional if_not_exists flag to check existence before insertion

    def __init__(self, **values):
        self._values = {}
        self._ttl = self.__default_ttl__
        self._timestamp = None
        self._transaction = None

        for name, column in self._columns.items():
            value = values.get(name, None)
            if value is not None or isinstance(column, columns.BaseContainerColumn):
                value = column.to_python(value)
            value_mngr = column.value_manager(self, column, value)
            if name in values:
                value_mngr.explicit = True
            self._values[name] = value_mngr

        # a flag set by the deserializer to indicate
        # that update should be used when persisting changes
        self._is_persisted = False
        self._batch = None
        self._timeout = NOT_SET

    def __repr__(self):
        """
        Pretty printing of models by their primary key
        """
        return '{} <{}>'.format(self.__class__.__name__,
                                ', '.join('{}={}'.format(k, getattr(self, k)) for k in self._primary_keys))

    @classmethod
    def _discover_polymorphic_submodels(cls):
        if not cls._is_polymorphic_base:
            raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes')

        def _discover(klass):
            if not klass._is_polymorphic_base and klass.__polymorphic_key__ is not None:
                cls._polymorphic_map[klass.__polymorphic_key__] = klass
            for subklass in klass.__subclasses__():
                _discover(subklass)
        _discover(cls)

    @classmethod
    def _get_model_by_polymorphic_key(cls, key):
        if not cls._is_polymorphic_base:
            raise ModelException('_get_model_by_polymorphic_key can only be called on polymorphic base classes')
        return cls._polymorphic_map.get(key)

    @classmethod
    def _construct_instance(cls, values):
        """
        method used to construct instances from query results
        this is where polymorphic deserialization occurs
        """
        # we're going to take the values, which is from the DB as a dict
        # and translate that into our local fields
        # the db_map is a db_field -> model field map
        items = values.items()
        field_dict = dict([(cls._db_map.get(k, k), v) for k, v in items])

        if cls._is_polymorphic:
            poly_key = field_dict.get(cls._polymorphic_column_name)

            if poly_key is None:
                raise PolyMorphicModelException('polymorphic key was not found in values')

            poly_base = cls if cls._is_polymorphic_base else cls._polymorphic_base

            klass = poly_base._get_model_by_polymorphic_key(poly_key)
            if klass is None:
                poly_base._discover_polymorphic_submodels()
                klass = poly_base._get_model_by_polymorphic_key(poly_key)
                if klass is None:
                    raise PolyMorphicModelException(
                        'unrecognized polymorphic key {} for class {}'.format(poly_key, poly_base.__name__)
                    )

            if not issubclass(klass, cls):
                raise PolyMorphicModelException(
                    '{} is not a subclass of {}'.format(klass.__name__, cls.__name__)
                )

            field_dict = {k: v for k, v in field_dict.items() if k in klass._columns.keys()}

        else:
            klass = cls

        instance = klass(**field_dict)
        instance._is_persisted = True
        return instance

    def _can_update(self):
        """
        Called by the save function to check if this should be
        persisted with update or insert

        :return: True if this instance should be persisted with an update
        """
        if not self._is_persisted: return False
        return all([not self._values[k].changed for k in self._primary_keys])

    @classmethod
    def _get_keyspace(cls):
        """ Returns the manual keyspace, if set, otherwise the default keyspace """
        return cls.__keyspace__ or DEFAULT_KEYSPACE

    @classmethod
    def _get_column(cls, name):
        """
        Returns the column matching the given name, raising a key error if
        it doesn't exist

        :param name: the name of the column to return
        :rtype: Column
        """
        return cls._columns[name]

    def __eq__(self, other):
        if self.__class__ != other.__class__:
            return False

        # check attribute keys
        keys = set(self._columns.keys())
        other_keys = set(other._columns.keys())
        if keys != other_keys:
            return False

        # check that all of the attributes match
        for key in other_keys:
            if getattr(self, key, None) != getattr(other, key, None):
                return False

        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    @classmethod
    def column_family_name(cls, include_keyspace=True):
        """
        Returns the column family name if it's been defined,
        otherwise it creates it from the class name
        """
        cf_name = ''
        if cls.__table_name__:
            cf_name = cls.__table_name__.lower()
        else:
            # get polymorphic base table names if model is polymorphic
            if cls._is_polymorphic and not cls._is_polymorphic_base:
                return cls._polymorphic_base.column_family_name(include_keyspace=include_keyspace)

            camelcase = re.compile(r'([a-z])([A-Z])')
            ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)

            cf_name += ccase(cls.__name__)
            # trim to less than 48 characters or cassandra will complain
            cf_name = cf_name[-48:]
            cf_name = cf_name.lower()
            cf_name = re.sub(r'^_+', '', cf_name)
        if not include_keyspace: return cf_name
        return '{}.{}'.format(cls._get_keyspace(), cf_name)

    def validate(self):
        """ Cleans and validates the field values """
        for name, col in self._columns.items():
            v = getattr(self, name)
            if v is None and not self._values[name].explicit and col.has_default:
                v = col.get_default()
            val = col.validate(v)
            setattr(self, name, val)

    ### Let an instance be used like a dict of its column keys/values

    def __iter__(self):
        """ Iterate over column ids. """
        for column_id in self._columns.keys():
            yield column_id

    def __getitem__(self, key):
        """ Returns a column's value. """
        if not isinstance(key, six.string_types):
            raise TypeError
        if key not in self._columns.keys():
            raise KeyError
        return getattr(self, key)

    def __setitem__(self, key, val):
        """ Sets a column's value. """
        if not isinstance(key, six.string_types):
            raise TypeError
        if key not in self._columns.keys():
            raise KeyError
        return setattr(self, key, val)

    def __len__(self):
        """ Returns the number of columns defined on the model. """
        try:
            return self._len
        except AttributeError:
            self._len = len(self._columns.keys())
            return self._len

    def keys(self):
        """ Returns a list of column IDs. """
        return [k for k in self]

    def values(self):
        """ Returns a list of column values. """
        return [self[k] for k in self]

    def items(self):
        """ Returns a list of column ID/value tuples. """
        return [(k, self[k]) for k in self]

    def _as_dict(self):
        """ Returns a map of column names to cleaned values """
        values = self._dynamic_columns or {}
        for name, col in self._columns.items():
            values[name] = col.to_database(getattr(self, name, None))
        return values

    @classmethod
    def create(cls, **kwargs):
        extra_columns = set(kwargs.keys()) - set(cls._columns.keys())
        if extra_columns:
            raise ValidationError("Incorrect columns passed: {}".format(extra_columns))
        return cls.objects.create(**kwargs)

    @classmethod
    def all(cls):
        return cls.objects.all()

    @classmethod
    def filter(cls, *args, **kwargs):
        # if kwargs.values().count(None):
        #     raise CQLEngineException("Cannot pass None as a filter")

        return cls.objects.filter(*args, **kwargs)

    @classmethod
    def get(cls, *args, **kwargs):
        return cls.objects.get(*args, **kwargs)

    def timeout(self, timeout):
        assert self._batch is None, 'Setting both timeout and batch is not supported'
        self._timeout = timeout
        return self

    def save(self):
        # handle polymorphic models
        if self._is_polymorphic:
            if self._is_polymorphic_base:
                raise PolyMorphicModelException('cannot save polymorphic base model')
            else:
                setattr(self, self._polymorphic_column_name, self.__polymorphic_key__)

        self.validate()
        self.__dmlquery__(self.__class__, self,
                          batch=self._batch,
                          ttl=self._ttl,
                          timestamp=self._timestamp,
                          consistency=self.__consistency__,
                          if_not_exists=self._if_not_exists,
                          transaction=self._transaction,
                          timeout=self._timeout).save()

        # reset the value managers
        for v in self._values.values():
            v.reset_previous_value()
        self._is_persisted = True

        self._ttl = self.__default_ttl__
        self._timestamp = None

        return self

    def update(self, **values):
        for k, v in values.items():
            col = self._columns.get(k)

            # check for nonexistent columns
            if col is None:
                raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.__class__.__name__, k))

            # check for primary key update attempts
            if col.is_primary_key:
                raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(k, self.__module__, self.__class__.__name__))

            setattr(self, k, v)

        # handle polymorphic models
        if self._is_polymorphic:
            if self._is_polymorphic_base:
                raise PolyMorphicModelException('cannot update polymorphic base model')
            else:
                setattr(self, self._polymorphic_column_name, self.__polymorphic_key__)

        self.validate()
        self.__dmlquery__(self.__class__, self,
                          batch=self._batch,
                          ttl=self._ttl,
                          timestamp=self._timestamp,
                          consistency=self.__consistency__,
                          transaction=self._transaction,
                          timeout=self._timeout).update()

        # reset the value managers
        for v in self._values.values():
            v.reset_previous_value()
        self._is_persisted = True

        self._ttl = self.__default_ttl__
        self._timestamp = None

        return self

    def delete(self):
        """ Deletes this instance """
        self.__dmlquery__(self.__class__, self,
                          batch=self._batch,
                          timestamp=self._timestamp,
                          consistency=self.__consistency__,
                          timeout=self._timeout).delete()

    def get_changed_columns(self):
        """ Returns a list of the columns that have been updated since instantiation or save """
        return [k for k, v in self._values.items() if v.changed]

    @classmethod
    def _class_batch(cls, batch):
        return cls.objects.batch(batch)

    def _inst_batch(self, batch):
        assert self._timeout is NOT_SET, 'Setting both timeout and batch is not supported'
        self._batch = batch
        return self

    batch = hybrid_classmethod(_class_batch, _inst_batch)
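

# Illustrative lifecycle sketch for the instance methods above (model and column
# names are assumptions, not defined in this module):
#
#   em = ExampleModel.create(id=some_id, count=0)   # INSERT via objects.create
#   em.count = 1
#   em.get_changed_columns()                        # ['count']
#   em.save()                                       # persists again; an update is used when _can_update() allows it
#   em.update(count=2)                              # explicit update of the given columns
#   em.delete()                                     # DELETE this row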


class ModelMetaClass(type):

    def __new__(cls, name, bases, attrs):
        """
        Collects the column definitions declared on the class and its bases,
        wires them up as ColumnDescriptors, and attaches the management
        attributes (column maps, key maps, polymorphic bookkeeping) used by
        BaseModel and the query machinery.
        """
        # move column definitions into columns dict
        # and set default column names
        column_dict = OrderedDict()
        primary_keys = OrderedDict()
        pk_name = None

        # get inherited properties
        inherited_columns = OrderedDict()
        for base in bases:
            for k, v in getattr(base, '_defined_columns', {}).items():
                inherited_columns.setdefault(k, v)

        # short circuit __abstract__ inheritance
        is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False)

        # short circuit __polymorphic_key__ inheritance
        attrs['__polymorphic_key__'] = attrs.get('__polymorphic_key__', None)

        def _transform_column(col_name, col_obj):
            column_dict[col_name] = col_obj
            if col_obj.primary_key:
                primary_keys[col_name] = col_obj
            col_obj.set_column_name(col_name)
            # set properties
            attrs[col_name] = ColumnDescriptor(col_obj)

        column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)]
        # column_definitions = sorted(column_definitions, lambda x, y: cmp(x[1].position, y[1].position))
        column_definitions = sorted(column_definitions, key=lambda x: x[1].position)

        is_polymorphic_base = any([c[1].polymorphic_key for c in column_definitions])

        column_definitions = [x for x in inherited_columns.items()] + column_definitions
        polymorphic_columns = [c for c in column_definitions if c[1].polymorphic_key]
        is_polymorphic = len(polymorphic_columns) > 0
        if len(polymorphic_columns) > 1:
            raise ModelDefinitionException('only one polymorphic_key can be defined in a model, {} found'.format(len(polymorphic_columns)))

        polymorphic_column_name, polymorphic_column = polymorphic_columns[0] if polymorphic_columns else (None, None)

        if isinstance(polymorphic_column, (columns.BaseContainerColumn, columns.Counter)):
            raise ModelDefinitionException('counter and container columns cannot be used for polymorphic keys')

        # find polymorphic base class
        polymorphic_base = None
        if is_polymorphic and not is_polymorphic_base:
            def _get_polymorphic_base(bases):
                for base in bases:
                    if getattr(base, '_is_polymorphic_base', False):
                        return base
                    klass = _get_polymorphic_base(base.__bases__)
                    if klass:
                        return klass
            polymorphic_base = _get_polymorphic_base(bases)

        defined_columns = OrderedDict(column_definitions)

        # check for primary key
        if not is_abstract and not any([v.primary_key for k, v in column_definitions]):
            raise ModelDefinitionException("At least 1 primary key is required.")

        counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)]
        data_columns = [c for c in defined_columns.values() if not c.primary_key and not isinstance(c, columns.Counter)]
        if counter_columns and data_columns:
            raise ModelDefinitionException('counter models may not have data columns')

        has_partition_keys = any(v.partition_key for (k, v) in column_definitions)

        # transform column definitions
        for k, v in column_definitions:
            # don't allow a column with the same name as a built-in attribute or method
            if k in BaseModel.__dict__:
                raise ModelDefinitionException("column '{}' conflicts with built-in attribute/method".format(k))

            # counter column primary keys are not allowed
            if (v.primary_key or v.partition_key) and isinstance(v, (columns.Counter, columns.BaseContainerColumn)):
                raise ModelDefinitionException('counter columns and container columns cannot be used as primary keys')

            # this will mark the first primary key column as a partition
            # key, if one hasn't been set already
            if not has_partition_keys and v.primary_key:
                v.partition_key = True
                has_partition_keys = True
            _transform_column(k, v)

        partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key)
        clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key)

        # setup partition key shortcut
        if len(partition_keys) == 0:
            if not is_abstract:
                raise ModelException("at least one partition key must be defined")
        if len(partition_keys) == 1:
            pk_name = [x for x in partition_keys.keys()][0]
            attrs['pk'] = attrs[pk_name]
        else:
            # composite partition key case, get/set a tuple of values
            _get = lambda self: tuple(self._values[c].getval() for c in partition_keys.keys())
            _set = lambda self, val: tuple(self._values[c].setval(v) for (c, v) in zip(partition_keys.keys(), val))
            attrs['pk'] = property(_get, _set)

        # some validation
        col_names = set()
        for v in column_dict.values():
            # check for duplicate column names
            if v.db_field_name in col_names:
                raise ModelException("{} defines the column {} more than once".format(name, v.db_field_name))
            if v.clustering_order and not (v.primary_key and not v.partition_key):
                raise ModelException("clustering_order may be specified only for clustering primary keys")
            if v.clustering_order and v.clustering_order.lower() not in ('asc', 'desc'):
                raise ModelException("invalid clustering order {} for column {}".format(repr(v.clustering_order), v.db_field_name))
            col_names.add(v.db_field_name)

        # create db_name -> model name map for loading
        db_map = {}
        for field_name, col in column_dict.items():
            db_map[col.db_field_name] = field_name

        # add management members to the class
        attrs['_columns'] = column_dict
        attrs['_primary_keys'] = primary_keys
        attrs['_defined_columns'] = defined_columns

        # maps the database field to the models key
        attrs['_db_map'] = db_map
        attrs['_pk_name'] = pk_name
        attrs['_dynamic_columns'] = {}

        attrs['_partition_keys'] = partition_keys
        attrs['_clustering_keys'] = clustering_keys
        attrs['_has_counter'] = len(counter_columns) > 0

        # add polymorphic management attributes
        attrs['_is_polymorphic_base'] = is_polymorphic_base
        attrs['_is_polymorphic'] = is_polymorphic
        attrs['_polymorphic_base'] = polymorphic_base
        attrs['_polymorphic_column'] = polymorphic_column
        attrs['_polymorphic_column_name'] = polymorphic_column_name
        attrs['_polymorphic_map'] = {} if is_polymorphic_base else None

        # setup class exceptions
        DoesNotExistBase = None
        for base in bases:
            DoesNotExistBase = getattr(base, 'DoesNotExist', None)
            if DoesNotExistBase is not None: break
        DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist)
        attrs['DoesNotExist'] = type('DoesNotExist', (DoesNotExistBase,), {})

        MultipleObjectsReturnedBase = None
        for base in bases:
            MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None)
            if MultipleObjectsReturnedBase is not None: break
        MultipleObjectsReturnedBase = MultipleObjectsReturnedBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned)
        attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {})

        # create the class and add a QuerySet to it
        klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs)

        return klass
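

# Illustrative sketch (model name assumed): after the metaclass runs, a concrete
# model class carries the management attributes assembled above, e.g.
#
#   ExampleModel._columns          # OrderedDict of column name -> Column
#   ExampleModel._partition_keys   # OrderedDict of partition key columns
#   ExampleModel._db_map           # db_field -> model field name map
#   ExampleModel.pk                # shortcut to the (possibly composite) partition key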


@six.add_metaclass(ModelMetaClass)
class Model(BaseModel):
    """
    The base class to inherit from when defining models; the table name for
    the column family can be set with the __table_name__ attribute, otherwise
    it will be generated from the class name
    """
    __abstract__ = True
    # __metaclass__ = ModelMetaClass
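

# Illustrative sketch of defining models against Model above. The column types
# and all names here (keyspace, models, key values) are assumptions for
# illustration; they are not defined in this module.
#
#   class ExampleModel(Model):
#       __keyspace__ = 'example_ks'
#       id = columns.UUID(primary_key=True)
#       name = columns.Text()
#       count = columns.Integer()
#
#   # single-table polymorphism: a polymorphic_key column on the base model,
#   # and a __polymorphic_key__ value on each concrete subclass
#   class Pet(Model):
#       id = columns.UUID(primary_key=True)
#       pet_type = columns.Text(polymorphic_key=True)
#
#   class Cat(Pet):
#       __polymorphic_key__ = 'cat'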