216 lines
		
	
	
		
			6.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			216 lines
		
	
	
		
			6.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import re
 | |
| import six
 | |
| 
 | |
| from cassandra.util import OrderedDict
 | |
| from cassandra.cqlengine import CQLEngineException
 | |
| from cassandra.cqlengine import columns
 | |
| from cassandra.cqlengine import connection
 | |
| from cassandra.cqlengine import models
 | |
| 
 | |
| 
 | |
| class UserTypeException(CQLEngineException):
 | |
|     pass
 | |
| 
 | |
| 
 | |
| class UserTypeDefinitionException(UserTypeException):
 | |
|     pass
 | |
| 
 | |
| 
 | |
| class BaseUserType(object):
 | |
|     """
 | |
|     The base type class; don't inherit from this, inherit from UserType, defined below
 | |
|     """
 | |
|     __type_name__ = None
 | |
| 
 | |
|     _fields = None
 | |
|     _db_map = None
 | |
| 
 | |
|     def __init__(self, **values):
 | |
|         self._values = {}
 | |
|         if self._db_map:
 | |
|             values = dict((self._db_map.get(k, k), v) for k, v in values.items())
 | |
| 
 | |
|         for name, field in self._fields.items():
 | |
|             field_default = field.get_default() if field.has_default else None
 | |
|             value = values.get(name, field_default)
 | |
|             if value is not None or isinstance(field, columns.BaseContainerColumn):
 | |
|                 value = field.to_python(value)
 | |
|             value_mngr = field.value_manager(self, field, value)
 | |
|             value_mngr.explicit = name in values
 | |
|             self._values[name] = value_mngr
 | |
| 
 | |
|     def __eq__(self, other):
 | |
|         if self.__class__ != other.__class__:
 | |
|             return False
 | |
| 
 | |
|         keys = set(self._fields.keys())
 | |
|         other_keys = set(other._fields.keys())
 | |
|         if keys != other_keys:
 | |
|             return False
 | |
| 
 | |
|         for key in other_keys:
 | |
|             if getattr(self, key, None) != getattr(other, key, None):
 | |
|                 return False
 | |
| 
 | |
|         return True
 | |
| 
 | |
|     def __ne__(self, other):
 | |
|         return not self.__eq__(other)
 | |
| 
 | |
|     def __str__(self):
 | |
|         return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k, v in six.iteritems(self._values)))
 | |
| 
 | |
|     def has_changed_fields(self):
 | |
|         return any(v.changed for v in self._values.values())
 | |
| 
 | |
|     def reset_changed_fields(self):
 | |
|         for v in self._values.values():
 | |
|             v.reset_previous_value()
 | |
| 
 | |
|     def __iter__(self):
 | |
|         for field in self._fields.keys():
 | |
|             yield field
 | |
| 
 | |
|     def __getattr__(self, attr):
 | |
|         # provides the mapping from db_field to fields
 | |
|         try:
 | |
|             return getattr(self, self._db_map[attr])
 | |
|         except KeyError:
 | |
|             raise AttributeError(attr)
 | |
| 
 | |
|     def __getitem__(self, key):
 | |
|         if not isinstance(key, six.string_types):
 | |
|             raise TypeError
 | |
|         if key not in self._fields.keys():
 | |
|             raise KeyError
 | |
|         return getattr(self, key)
 | |
| 
 | |
|     def __setitem__(self, key, val):
 | |
|         if not isinstance(key, six.string_types):
 | |
|             raise TypeError
 | |
|         if key not in self._fields.keys():
 | |
|             raise KeyError
 | |
|         return setattr(self, key, val)
 | |
| 
 | |
|     def __len__(self):
 | |
|         try:
 | |
|             return self._len
 | |
|         except:
 | |
|             self._len = len(self._fields.keys())
 | |
|             return self._len
 | |
| 
 | |
|     def keys(self):
 | |
|         """ Returns a list of column IDs. """
 | |
|         return [k for k in self]
 | |
| 
 | |
|     def values(self):
 | |
|         """ Returns list of column values. """
 | |
|         return [self[k] for k in self]
 | |
| 
 | |
|     def items(self):
 | |
|         """ Returns a list of column ID/value tuples. """
 | |
|         return [(k, self[k]) for k in self]
 | |
| 
 | |
|     @classmethod
 | |
|     def register_for_keyspace(cls, keyspace):
 | |
|         connection.register_udt(keyspace, cls.type_name(), cls)
 | |
| 
 | |
|     @classmethod
 | |
|     def type_name(cls):
 | |
|         """
 | |
|         Returns the type name if it's been defined
 | |
|         otherwise, it creates it from the class name
 | |
|         """
 | |
|         if cls.__type_name__:
 | |
|             type_name = cls.__type_name__.lower()
 | |
|         else:
 | |
|             camelcase = re.compile(r'([a-z])([A-Z])')
 | |
|             ccase = lambda s: camelcase.sub(lambda v: '{0}_{1}'.format(v.group(1), v.group(2)), s)
 | |
| 
 | |
|             type_name = ccase(cls.__name__)
 | |
|             # trim to less than 48 characters or cassandra will complain
 | |
|             type_name = type_name[-48:]
 | |
|             type_name = type_name.lower()
 | |
|             type_name = re.sub(r'^_+', '', type_name)
 | |
|             cls.__type_name__ = type_name
 | |
| 
 | |
|         return type_name
 | |
| 
 | |
|     def validate(self):
 | |
|         """
 | |
|         Cleans and validates the field values
 | |
|         """
 | |
|         for name, field in self._fields.items():
 | |
|             v = getattr(self, name)
 | |
|             if v is None and not self._values[name].explicit and field.has_default:
 | |
|                 v = field.get_default()
 | |
|             val = field.validate(v)
 | |
|             setattr(self, name, val)
 | |
| 
 | |
| 
 | |
| class UserTypeMetaClass(type):
 | |
| 
 | |
|     def __new__(cls, name, bases, attrs):
 | |
|         field_dict = OrderedDict()
 | |
| 
 | |
|         field_defs = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)]
 | |
|         field_defs = sorted(field_defs, key=lambda x: x[1].position)
 | |
| 
 | |
|         def _transform_column(field_name, field_obj):
 | |
|             field_dict[field_name] = field_obj
 | |
|             field_obj.set_column_name(field_name)
 | |
|             attrs[field_name] = models.ColumnDescriptor(field_obj)
 | |
| 
 | |
|         # transform field definitions
 | |
|         for k, v in field_defs:
 | |
|             # don't allow a field with the same name as a built-in attribute or method
 | |
|             if k in BaseUserType.__dict__:
 | |
|                 raise UserTypeDefinitionException("field '{0}' conflicts with built-in attribute/method".format(k))
 | |
|             _transform_column(k, v)
 | |
| 
 | |
|         attrs['_fields'] = field_dict
 | |
| 
 | |
|         db_map = {}
 | |
|         for field_name, field in field_dict.items():
 | |
|             db_field = field.db_field_name
 | |
|             if db_field != field_name:
 | |
|                 if db_field in field_dict:
 | |
|                     raise UserTypeDefinitionException("db_field '{0}' for field '{1}' conflicts with another attribute name".format(db_field, field_name))
 | |
|                 db_map[db_field] = field_name
 | |
|         attrs['_db_map'] = db_map
 | |
| 
 | |
|         klass = super(UserTypeMetaClass, cls).__new__(cls, name, bases, attrs)
 | |
| 
 | |
|         return klass
 | |
| 
 | |
| 
 | |
| @six.add_metaclass(UserTypeMetaClass)
 | |
| class UserType(BaseUserType):
 | |
|     """
 | |
|     This class is used to model User Defined Types. To define a type, declare a class inheriting from this,
 | |
|     and assign field types as class attributes:
 | |
| 
 | |
|     .. code-block:: python
 | |
| 
 | |
|         # connect with default keyspace ...
 | |
| 
 | |
|         from cassandra.cqlengine.columns import Text, Integer
 | |
|         from cassandra.cqlengine.usertype import UserType
 | |
| 
 | |
|         class address(UserType):
 | |
|             street = Text()
 | |
|             zipcode = Integer()
 | |
| 
 | |
|         from cassandra.cqlengine import management
 | |
|         management.sync_type(address)
 | |
| 
 | |
|     Please see :ref:`user_types` for a complete example and discussion.
 | |
|     """
 | |
| 
 | |
|     __type_name__ = None
 | |
|     """
 | |
|     *Optional.* Sets the name of the CQL type for this type.
 | |
| 
 | |
|     If not specified, the type name will be the name of the class, with it's module name as it's prefix.
 | |
|     """
 | 
