Register udt namedtuple types in cassandra.cqltypes

This provides type visibility for pickling when deserializing
unregistered udts as keys to maps.
This commit is contained in:
Adam Holmberg
2015-01-22 17:59:04 -06:00
parent a2ed58e72d
commit ad920171d7

View File

@@ -31,14 +31,14 @@ from __future__ import absolute_import # to enable import io from stdlib
from binascii import unhexlify
import calendar
from collections import namedtuple
from datetime import datetime, timedelta
from decimal import Decimal
import io
import re
import socket
import time
from datetime import datetime, timedelta
import sys
from uuid import UUID
import warnings
import six
from six.moves import range
@@ -823,22 +823,23 @@ class UserType(TupleType):
typename = "'org.apache.cassandra.db.marshal.UserType'"
_cache = {}
_module = sys.modules[__name__]
@classmethod
def make_udt_class(cls, keyspace, udt_name, names_and_types, mapped_class):
if six.PY2 and isinstance(udt_name, unicode):
udt_name = udt_name.encode('utf-8')
try:
return cls._cache[(keyspace, udt_name)]
except KeyError:
fieldnames, types = zip(*names_and_types)
field_names, types = zip(*names_and_types)
instance = type(udt_name, (cls,), {'subtypes': types,
'cassname': cls.cassname,
'typename': udt_name,
'fieldnames': fieldnames,
'fieldnames': field_names,
'keyspace': keyspace,
'mapped_class': mapped_class})
'mapped_class': mapped_class,
'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)})
cls._cache[(keyspace, udt_name)] = instance
return instance
@@ -853,7 +854,8 @@ class UserType(TupleType):
'typename': udt_name,
'fieldnames': field_names,
'keyspace': keyspace,
'mapped_class': None})
'mapped_class': None,
'tuple_type': namedtuple(udt_name, field_names)})
@classmethod
def cql_parameterized_type(cls):
@@ -885,8 +887,7 @@ class UserType(TupleType):
if cls.mapped_class:
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
else:
Result = namedtuple(cls.typename, cls.fieldnames)
return Result(*values)
return cls.tuple_type(*values)
@classmethod
def serialize_safe(cls, val, protocol_version):
@@ -902,6 +903,18 @@ class UserType(TupleType):
buf.write(int32_pack(-1))
return buf.getvalue()
@classmethod
def _make_registered_udt_namedtuple(cls, keyspace, name, field_names):
# this is required to make the type resolvable via this module...
# required when unregistered udts are pickled for use as keys in
# util.OrderedMap
qualified_name = "%s_%s" % (keyspace, name)
nt = getattr(cls._module, qualified_name, None)
if not nt:
nt = namedtuple(qualified_name, field_names)
setattr(cls._module, qualified_name, nt)
return nt
class CompositeType(_ParameterizedType):
typename = "'org.apache.cassandra.db.marshal.CompositeType'"