Register udt namedtuple types in cassandra.cqltypes

This provides type visibility for pickling when deserializing
unregistered udts as keys to maps.
This commit is contained in:
Adam Holmberg
2015-01-22 17:59:04 -06:00
parent a2ed58e72d
commit ad920171d7

View File

@@ -31,14 +31,14 @@ from __future__ import absolute_import # to enable import io from stdlib
from binascii import unhexlify from binascii import unhexlify
import calendar import calendar
from collections import namedtuple from collections import namedtuple
from datetime import datetime, timedelta
from decimal import Decimal from decimal import Decimal
import io import io
import re import re
import socket import socket
import time import time
from datetime import datetime, timedelta import sys
from uuid import UUID from uuid import UUID
import warnings
import six import six
from six.moves import range from six.moves import range
@@ -823,22 +823,23 @@ class UserType(TupleType):
typename = "'org.apache.cassandra.db.marshal.UserType'" typename = "'org.apache.cassandra.db.marshal.UserType'"
_cache = {} _cache = {}
_module = sys.modules[__name__]
@classmethod @classmethod
def make_udt_class(cls, keyspace, udt_name, names_and_types, mapped_class): def make_udt_class(cls, keyspace, udt_name, names_and_types, mapped_class):
if six.PY2 and isinstance(udt_name, unicode): if six.PY2 and isinstance(udt_name, unicode):
udt_name = udt_name.encode('utf-8') udt_name = udt_name.encode('utf-8')
try: try:
return cls._cache[(keyspace, udt_name)] return cls._cache[(keyspace, udt_name)]
except KeyError: except KeyError:
fieldnames, types = zip(*names_and_types) field_names, types = zip(*names_and_types)
instance = type(udt_name, (cls,), {'subtypes': types, instance = type(udt_name, (cls,), {'subtypes': types,
'cassname': cls.cassname, 'cassname': cls.cassname,
'typename': udt_name, 'typename': udt_name,
'fieldnames': fieldnames, 'fieldnames': field_names,
'keyspace': keyspace, 'keyspace': keyspace,
'mapped_class': mapped_class}) 'mapped_class': mapped_class,
'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)})
cls._cache[(keyspace, udt_name)] = instance cls._cache[(keyspace, udt_name)] = instance
return instance return instance
@@ -853,7 +854,8 @@ class UserType(TupleType):
'typename': udt_name, 'typename': udt_name,
'fieldnames': field_names, 'fieldnames': field_names,
'keyspace': keyspace, 'keyspace': keyspace,
'mapped_class': None}) 'mapped_class': None,
'tuple_type': namedtuple(udt_name, field_names)})
@classmethod @classmethod
def cql_parameterized_type(cls): def cql_parameterized_type(cls):
@@ -885,8 +887,7 @@ class UserType(TupleType):
if cls.mapped_class: if cls.mapped_class:
return cls.mapped_class(**dict(zip(cls.fieldnames, values))) return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
else: else:
Result = namedtuple(cls.typename, cls.fieldnames) return cls.tuple_type(*values)
return Result(*values)
@classmethod @classmethod
def serialize_safe(cls, val, protocol_version): def serialize_safe(cls, val, protocol_version):
@@ -902,6 +903,18 @@ class UserType(TupleType):
buf.write(int32_pack(-1)) buf.write(int32_pack(-1))
return buf.getvalue() return buf.getvalue()
@classmethod
def _make_registered_udt_namedtuple(cls, keyspace, name, field_names):
# this is required to make the type resolvable via this module...
# required when unregistered udts are pickled for use as keys in
# util.OrderedMap
qualified_name = "%s_%s" % (keyspace, name)
nt = getattr(cls._module, qualified_name, None)
if not nt:
nt = namedtuple(qualified_name, field_names)
setattr(cls._module, qualified_name, nt)
return nt
class CompositeType(_ParameterizedType): class CompositeType(_ParameterizedType):
typename = "'org.apache.cassandra.db.marshal.CompositeType'" typename = "'org.apache.cassandra.db.marshal.CompositeType'"