Register udt namedtuple types in cassandra.cqltypes
This provides type visibility for pickling when deserializing unregistered udts as keys to maps.
This commit is contained in:
@@ -31,14 +31,14 @@ from __future__ import absolute_import # to enable import io from stdlib
|
||||
from binascii import unhexlify
|
||||
import calendar
|
||||
from collections import namedtuple
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
import io
|
||||
import re
|
||||
import socket
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
import sys
|
||||
from uuid import UUID
|
||||
import warnings
|
||||
|
||||
import six
|
||||
from six.moves import range
|
||||
@@ -823,22 +823,23 @@ class UserType(TupleType):
|
||||
typename = "'org.apache.cassandra.db.marshal.UserType'"
|
||||
|
||||
_cache = {}
|
||||
_module = sys.modules[__name__]
|
||||
|
||||
@classmethod
|
||||
def make_udt_class(cls, keyspace, udt_name, names_and_types, mapped_class):
|
||||
if six.PY2 and isinstance(udt_name, unicode):
|
||||
udt_name = udt_name.encode('utf-8')
|
||||
|
||||
try:
|
||||
return cls._cache[(keyspace, udt_name)]
|
||||
except KeyError:
|
||||
fieldnames, types = zip(*names_and_types)
|
||||
field_names, types = zip(*names_and_types)
|
||||
instance = type(udt_name, (cls,), {'subtypes': types,
|
||||
'cassname': cls.cassname,
|
||||
'typename': udt_name,
|
||||
'fieldnames': fieldnames,
|
||||
'fieldnames': field_names,
|
||||
'keyspace': keyspace,
|
||||
'mapped_class': mapped_class})
|
||||
'mapped_class': mapped_class,
|
||||
'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)})
|
||||
cls._cache[(keyspace, udt_name)] = instance
|
||||
return instance
|
||||
|
||||
@@ -853,7 +854,8 @@ class UserType(TupleType):
|
||||
'typename': udt_name,
|
||||
'fieldnames': field_names,
|
||||
'keyspace': keyspace,
|
||||
'mapped_class': None})
|
||||
'mapped_class': None,
|
||||
'tuple_type': namedtuple(udt_name, field_names)})
|
||||
|
||||
@classmethod
|
||||
def cql_parameterized_type(cls):
|
||||
@@ -885,8 +887,7 @@ class UserType(TupleType):
|
||||
if cls.mapped_class:
|
||||
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
|
||||
else:
|
||||
Result = namedtuple(cls.typename, cls.fieldnames)
|
||||
return Result(*values)
|
||||
return cls.tuple_type(*values)
|
||||
|
||||
@classmethod
|
||||
def serialize_safe(cls, val, protocol_version):
|
||||
@@ -902,6 +903,18 @@ class UserType(TupleType):
|
||||
buf.write(int32_pack(-1))
|
||||
return buf.getvalue()
|
||||
|
||||
@classmethod
|
||||
def _make_registered_udt_namedtuple(cls, keyspace, name, field_names):
|
||||
# this is required to make the type resolvable via this module...
|
||||
# required when unregistered udts are pickled for use as keys in
|
||||
# util.OrderedMap
|
||||
qualified_name = "%s_%s" % (keyspace, name)
|
||||
nt = getattr(cls._module, qualified_name, None)
|
||||
if not nt:
|
||||
nt = namedtuple(qualified_name, field_names)
|
||||
setattr(cls._module, qualified_name, nt)
|
||||
return nt
|
||||
|
||||
|
||||
class CompositeType(_ParameterizedType):
|
||||
typename = "'org.apache.cassandra.db.marshal.CompositeType'"
|
||||
|
||||
Reference in New Issue
Block a user