216 lines
6.7 KiB
Python
216 lines
6.7 KiB
Python
import re
|
|
import six
|
|
|
|
from cassandra.util import OrderedDict
|
|
from cassandra.cqlengine import CQLEngineException
|
|
from cassandra.cqlengine import columns
|
|
from cassandra.cqlengine import connection as conn
|
|
from cassandra.cqlengine import models
|
|
|
|
|
|
class UserTypeException(CQLEngineException):
    """Base class for the exceptions raised by the user type machinery in this module."""
|
|
|
|
|
|
class UserTypeDefinitionException(UserTypeException):
    """Raised while a user type class is being created and its field definitions are invalid."""
|
|
|
|
|
|
class BaseUserType(object):
    """
    The base type class; don't inherit from this, inherit from UserType, defined below
    """
    # Explicit CQL type name; when falsy, type_name() derives one from the class name.
    __type_name__ = None

    # Mapping of field name -> column object (filled in by UserTypeMetaClass).
    _fields = None
    # Mapping of db_field name -> field name, only for fields whose db name differs
    # (filled in by UserTypeMetaClass).
    _db_map = None

    def __init__(self, **values):
        """
        Builds one value manager per declared field.

        :param values: initial field values, keyed by field name or by db_field name.
            Fields that are not supplied fall back to the column default, if it has one.
        """
        self._values = {}
        if self._db_map:
            # translate db_field names into the python field names
            values = dict((self._db_map.get(k, k), v) for k, v in values.items())

        for name, field in self._fields.items():
            field_default = field.get_default() if field.has_default else None
            value = values.get(name, field_default)
            # container columns are converted even when the value is None
            if value is not None or isinstance(field, columns.BaseContainerColumn):
                value = field.to_python(value)
            value_mngr = field.value_manager(self, field, value)
            # remember whether the caller supplied this value; validate() relies on it
            value_mngr.explicit = name in values
            self._values[name] = value_mngr

    def __eq__(self, other):
        """Two instances are equal when they share a class, field names and field values."""
        if self.__class__ != other.__class__:
            return False

        keys = set(self._fields.keys())
        other_keys = set(other._fields.keys())
        if keys != other_keys:
            return False

        return all(getattr(self, key, None) == getattr(other, key, None)
                   for key in other_keys)

    def __ne__(self, other):
        """Negation of __eq__ (Python 2 does not derive it automatically)."""
        return not self.__eq__(other)

    def __str__(self):
        """Renders the fields as ``{'name': value, ...}``."""
        return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k in self._values))

    def has_changed_fields(self):
        """Returns True if any value manager reports its value as changed."""
        return any(v.changed for v in self._values.values())

    def reset_changed_fields(self):
        """Calls ``reset_previous_value()`` on every value manager."""
        for v in self._values.values():
            v.reset_previous_value()

    def __iter__(self):
        """Iterates over the field names."""
        return iter(self._fields)

    def __getattr__(self, attr):
        """
        Provides the mapping from db_field to fields.

        Only reached when normal attribute lookup fails. Raises AttributeError for
        any name that is not a known db_field, including when no db map has been
        set on the class (``_db_map`` is None), so that ``hasattr`` and
        three-argument ``getattr`` keep working.
        """
        try:
            return getattr(self, (self._db_map or {})[attr])
        except KeyError:
            raise AttributeError(attr)

    def __getitem__(self, key):
        """
        Returns the value of the field named ``key``.

        :raises TypeError: if ``key`` is not a string
        :raises KeyError: if ``key`` is not a declared field
        """
        if not isinstance(key, six.string_types):
            raise TypeError("field name must be a string, not {0}".format(type(key).__name__))
        if key not in self._fields:
            raise KeyError(key)
        return getattr(self, key)

    def __setitem__(self, key, val):
        """
        Sets the field named ``key`` to ``val``.

        :raises TypeError: if ``key`` is not a string
        :raises KeyError: if ``key`` is not a declared field
        """
        if not isinstance(key, six.string_types):
            raise TypeError("field name must be a string, not {0}".format(type(key).__name__))
        if key not in self._fields:
            raise KeyError(key)
        setattr(self, key, val)

    def __len__(self):
        """Returns the number of declared fields."""
        return len(self._fields)

    def keys(self):
        """ Returns a list of column IDs. """
        return list(self)

    def values(self):
        """ Returns list of column values. """
        return [self[k] for k in self]

    def items(self):
        """ Returns a list of column ID/value tuples. """
        return [(k, self[k]) for k in self]

    @classmethod
    def register_for_keyspace(cls, keyspace, connection=None):
        """Registers this class for its type name in ``keyspace`` through ``conn.register_udt``."""
        conn.register_udt(keyspace, cls.type_name(), cls, connection=connection)

    @classmethod
    def type_name(cls):
        """
        Returns the type name if it's been defined
        otherwise, it creates it from the class name
        """
        if cls.__type_name__:
            type_name = cls.__type_name__.lower()
        else:
            # CamelCase -> Camel_Case; lower-cased further down
            type_name = re.sub(r'([a-z])([A-Z])', r'\1_\2', cls.__name__)
            # keep at most the last 48 characters or cassandra will complain
            type_name = type_name[-48:]
            type_name = type_name.lower()
            # the cut may leave leading underscores behind
            type_name = re.sub(r'^_+', '', type_name)
            # cache the derived name on the class
            cls.__type_name__ = type_name

        return type_name

    def validate(self):
        """
        Cleans and validates the field values
        """
        for name, field in self._fields.items():
            v = getattr(self, name)
            # a value that was never supplied falls back to the column default
            if v is None and not self._values[name].explicit and field.has_default:
                v = field.get_default()
            val = field.validate(v)
            setattr(self, name, val)
|
|
|
|
|
|
class UserTypeMetaClass(type):
    """
    Metaclass that turns the column attributes declared on a class body into
    descriptors and records them in ``_fields`` and ``_db_map``.
    """

    def __new__(cls, name, bases, attrs):
        # columns declared in the class body, ordered by their ``position``
        declared = sorted(
            ((attr_name, attr) for attr_name, attr in attrs.items()
             if isinstance(attr, columns.Column)),
            key=lambda pair: pair[1].position)

        fields = OrderedDict()
        for attr_name, column in declared:
            # don't allow a field with the same name as a built-in attribute or method
            if attr_name in BaseUserType.__dict__:
                raise UserTypeDefinitionException("field '{0}' conflicts with built-in attribute/method".format(attr_name))
            fields[attr_name] = column
            column.set_column_name(attr_name)
            # the class attribute becomes a descriptor wrapping the column
            attrs[attr_name] = models.ColumnDescriptor(column)
        attrs['_fields'] = fields

        # db_field name -> field name, only for fields whose db name differs
        db_to_field = {}
        for field_name, column in fields.items():
            db_name = column.db_field_name
            if db_name == field_name:
                continue
            if db_name in fields:
                raise UserTypeDefinitionException("db_field '{0}' for field '{1}' conflicts with another attribute name".format(db_name, field_name))
            db_to_field[db_name] = field_name
        attrs['_db_map'] = db_to_field

        return super(UserTypeMetaClass, cls).__new__(cls, name, bases, attrs)
|
|
|
|
|
|
@six.add_metaclass(UserTypeMetaClass)
class UserType(BaseUserType):
    """
    Models a User Defined Type. To define a type, declare a class that inherits
    from this one and assign field types as class attributes:

    .. code-block:: python

        # connect with default keyspace ...

        from cassandra.cqlengine.columns import Text, Integer
        from cassandra.cqlengine.usertype import UserType

        class address(UserType):
            street = Text()
            zipcode = Integer()

        from cassandra.cqlengine import management
        management.sync_type(address)

    Please see :ref:`user_types` for a complete example and discussion.
    """

    __type_name__ = None
    """
    *Optional.* Sets the name of the CQL type for this type.

    If not specified, :meth:`type_name` derives the name from the class name:
    CamelCase is converted to lower-case snake_case and at most the last
    48 characters are kept.
    """
|