Use registered UDTs for non-prepared encoding

This commit is contained in:
Tyler Hobbs
2014-06-18 16:29:05 -05:00
parent 7a838e2350
commit 8480123f40
4 changed files with 54 additions and 29 deletions

View File

@@ -44,6 +44,7 @@ from itertools import groupby
from cassandra import (ConsistencyLevel, AuthenticationFailed,
OperationTimedOut, UnsupportedOperation)
from cassandra.connection import ConnectionException, ConnectionShutdown
from cassandra.encoder import cql_encode_all_types, cql_encoders
from cassandra.protocol import (QueryMessage, ResultMessage,
ErrorMessage, ReadTimeoutErrorMessage,
WriteTimeoutErrorMessage,
@@ -409,8 +410,7 @@ class Cluster(object):
self._listener_lock = Lock()
# let Session objects be GC'ed (and shutdown) when the user no longer
# holds a reference. Normally the cycle detector would handle this,
# but implementing __del__ prevents that.
# holds a reference.
self.sessions = WeakSet()
self.metadata = Metadata(self)
self.control_connection = None
@@ -451,8 +451,10 @@ class Cluster(object):
self.control_connection = ControlConnection(
self, self.control_connection_timeout)
def register_type_class(self, keyspace, user_type, klass):
def register_user_type(self, keyspace, user_type, klass):
self._user_types[keyspace][user_type] = klass
for session in self.sessions:
self.session.user_type_registered(keyspace, user_type, klass)
def get_min_requests_per_connection(self, host_distance):
return self._min_requests_per_connection[host_distance]
@@ -602,6 +604,9 @@ class Cluster(object):
def _new_session(self):
session = Session(self, self.metadata.all_hosts())
for keyspace, type_map in six.iteritems(self._user_types):
for udt_name, klass in six.iteritems(type_map):
session.user_type_registered(keyspace, udt_name, klass)
self.sessions.add(session)
return session
@@ -1064,6 +1069,19 @@ class Session(object):
_metrics = None
_protocol_version = None
encoders = None
def user_type_registered(self, keyspace, user_type, klass):
type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type]
def encode(val):
return '{ %s }' % ' , '.join('%s : %s' % (
field_name,
cql_encode_all_types(getattr(val, field_name))
) for field_name in type_meta.field_names)
self._encoders[klass] = encode
def __init__(self, cluster, hosts):
self.cluster = cluster
self.hosts = hosts
@@ -1074,6 +1092,8 @@ class Session(object):
self._metrics = cluster.metrics
self._protocol_version = self.cluster.protocol_version
self._encoders = cql_encoders.copy()
# create connection pools in parallel
futures = []
for host in hosts:
@@ -1196,7 +1216,7 @@ class Session(object):
if isinstance(query, SimpleStatement):
query_string = query.query_string
if parameters:
query_string = bind_params(query.query_string, parameters)
query_string = bind_params(query.query_string, parameters, self._encoders)
message = QueryMessage(
query_string, cl, query.serial_consistency_level,
fetch_size, timestamp=timestamp)
@@ -1701,8 +1721,8 @@ class ControlConnection(object):
cf_query, col_query)
log.debug("[control connection] Fetched table info for %s.%s, rebuilding metadata", (keyspace, table))
cf_result = dict_factory(*cf_result.results)
col_result = dict_factory(*col_result.results)
cf_result = dict_factory(*cf_result.results) if cf_result else {}
col_result = dict_factory(*col_result.results) if col_result else {}
self._cluster.metadata.table_changed(keyspace, table, cf_result, col_result)
elif usertype:
# user defined types within this keyspace changed
@@ -1710,7 +1730,7 @@ class ControlConnection(object):
types_query = QueryMessage(query=self._SELECT_USERTYPES + where_clause, consistency_level=cl)
types_result = connection.wait_for_response(types_query)
log.debug("[control connection] Fetched user type info for %s.%s, rebuilding metadata", (keyspace, usertype))
types_result = dict_factory(*types_result)
types_result = dict_factory(*types_result.results) if types_result.results else {}
self._cluster.metadata.usertype_changed(keyspace, usertype, types_result)
elif keyspace:
# only the keyspace itself changed (such as replication settings)
@@ -1718,7 +1738,7 @@ class ControlConnection(object):
ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl)
ks_result = connection.wait_for_response(ks_query)
log.debug("[control connection] Fetched keyspace info for %s, rebuilding metadata", (keyspace,))
ks_result = dict_factory(*types_result)
ks_result = dict_factory(*ks_result.results) if ks_result.results else {}
self._cluster.metadata.keyspace_changed(keyspace, ks_result)
else:
# build everything from scratch
@@ -1730,12 +1750,12 @@ class ControlConnection(object):
if self._protocol_version >= 3:
queries.append(QueryMessage(query=self._SELECT_USERTYPES, consistency_level=cl))
ks_result, cf_result, col_result, types_result = connection.wait_for_responses(*queries)
types_result = dict_factory(*types_result)
types_result = dict_factory(*types_result.results) if types_result.results else {}
else:
ks_result, cf_result, col_result = connection.wait_for_responses(*queries)
types_result = {}
ks_result = dict_factory(*types_result)
ks_result = dict_factory(*ks_result.results)
cf_result = dict_factory(*cf_result.results)
col_result = dict_factory(*col_result.results)

View File

@@ -560,14 +560,14 @@ class ResultMessage(_MessageType):
ksname = read_string(f)
results = ksname
elif kind == RESULT_KIND_PREPARED:
results = cls.recv_results_prepared(f)
results = cls.recv_results_prepared(f, user_type_map)
elif kind == RESULT_KIND_SCHEMA_CHANGE:
results = cls.recv_results_schema_change(f, protocol_version)
return cls(kind, results, paging_state)
@classmethod
def recv_results_rows(cls, f, protocol_version, user_type_map):
paging_state, column_metadata = cls.recv_results_metadata(f)
paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map)
rowcount = read_int(f)
rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
colnames = [c[2] for c in column_metadata]
@@ -579,9 +579,9 @@ class ResultMessage(_MessageType):
return (paging_state, (colnames, parsed_rows))
@classmethod
def recv_results_prepared(cls, f):
def recv_results_prepared(cls, f, user_type_map):
query_id = read_binary_string(f)
_, column_metadata = cls.recv_results_metadata(f)
_, column_metadata = cls.recv_results_metadata(f, user_type_map)
return (query_id, column_metadata)
@classmethod

View File

@@ -567,9 +567,10 @@ class BatchStatement(Statement):
"""
_statements_and_parameters = None
_session = None
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
consistency_level=None):
consistency_level=None, session=None):
"""
`batch_type` specifies The :class:`.BatchType` for the batch operation.
Defaults to :attr:`.BatchType.LOGGED`.
@@ -605,6 +606,7 @@ class BatchStatement(Statement):
"""
self.batch_type = batch_type
self._statements_and_parameters = []
self._session = session
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level)
def add(self, statement, parameters=None):
@@ -617,7 +619,8 @@ class BatchStatement(Statement):
"""
if isinstance(statement, six.string_types):
if parameters:
statement = bind_params(statement, parameters)
encoders = cql_encoders if self._session is None else self._session.encoders
statement = bind_params(statement, parameters, encoders)
self._statements_and_parameters.append((False, statement, ()))
elif isinstance(statement, PreparedStatement):
query_id = statement.query_id
@@ -635,7 +638,8 @@ class BatchStatement(Statement):
# it must be a SimpleStatement
query_string = statement.query_string
if parameters:
query_string = bind_params(query_string, parameters)
encoders = cql_encoders if self._session is None else self._session.encoders
query_string = bind_params(query_string, parameters, encoders)
self._statements_and_parameters.append((False, query_string, ()))
return self
@@ -677,11 +681,11 @@ class ValueSequence(object):
return cql_encode_sequence(self.sequence)
def bind_params(query, params):
def bind_params(query, params, encoders):
if isinstance(params, dict):
return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
return query % dict((k, encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
else:
return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v) for v in params)
return query % tuple(encoders.get(type(v), cql_encode_object)(v) for v in params)
class TraceUnavailable(Exception):

View File

@@ -17,6 +17,7 @@ try:
except ImportError:
import unittest # noqa
from cassandra.encoder import cql_encoders
from cassandra.query import bind_params, ValueSequence
from cassandra.query import PreparedStatement, BoundStatement
from cassandra.cqltypes import Int32Type
@@ -28,31 +29,31 @@ from six.moves import xrange
class ParamBindingTest(unittest.TestCase):
def test_bind_sequence(self):
result = bind_params("%s %s %s", (1, "a", 2.0))
result = bind_params("%s %s %s", (1, "a", 2.0), cql_encoders)
self.assertEqual(result, "1 'a' 2.0")
def test_bind_map(self):
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0))
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), cql_encoders)
self.assertEqual(result, "1 'a' 2.0")
def test_sequence_param(self):
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),))
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), cql_encoders)
self.assertEqual(result, "( 1 , 'a' , 2.0 )")
def test_generator_param(self):
result = bind_params("%s", ((i for i in xrange(3)),))
result = bind_params("%s", ((i for i in xrange(3)),), cql_encoders)
self.assertEqual(result, "[ 0 , 1 , 2 ]")
def test_none_param(self):
result = bind_params("%s", (None,))
result = bind_params("%s", (None,), cql_encoders)
self.assertEqual(result, "NULL")
def test_list_collection(self):
result = bind_params("%s", (['a', 'b', 'c'],))
result = bind_params("%s", (['a', 'b', 'c'],), cql_encoders)
self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
def test_set_collection(self):
result = bind_params("%s", (set(['a', 'b']),))
result = bind_params("%s", (set(['a', 'b']),), cql_encoders)
self.assertIn(result, ("{ 'a' , 'b' }", "{ 'b' , 'a' }"))
def test_map_collection(self):
@@ -60,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
vals['a'] = 'a'
vals['b'] = 'b'
vals['c'] = 'c'
result = bind_params("%s", (vals,))
result = bind_params("%s", (vals,), cql_encoders)
self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
def test_quote_escaping(self):
result = bind_params("%s", ("""'ef''ef"ef""ef'""",))
result = bind_params("%s", ("""'ef''ef"ef""ef'""",), cql_encoders)
self.assertEqual(result, """'''ef''''ef"ef""ef'''""")