Use registered UDTs for non-prepared encoding
This commit is contained in:
@@ -44,6 +44,7 @@ from itertools import groupby
|
||||
from cassandra import (ConsistencyLevel, AuthenticationFailed,
|
||||
OperationTimedOut, UnsupportedOperation)
|
||||
from cassandra.connection import ConnectionException, ConnectionShutdown
|
||||
from cassandra.encoder import cql_encode_all_types, cql_encoders
|
||||
from cassandra.protocol import (QueryMessage, ResultMessage,
|
||||
ErrorMessage, ReadTimeoutErrorMessage,
|
||||
WriteTimeoutErrorMessage,
|
||||
@@ -409,8 +410,7 @@ class Cluster(object):
|
||||
self._listener_lock = Lock()
|
||||
|
||||
# let Session objects be GC'ed (and shutdown) when the user no longer
|
||||
# holds a reference. Normally the cycle detector would handle this,
|
||||
# but implementing __del__ prevents that.
|
||||
# holds a reference.
|
||||
self.sessions = WeakSet()
|
||||
self.metadata = Metadata(self)
|
||||
self.control_connection = None
|
||||
@@ -451,8 +451,10 @@ class Cluster(object):
|
||||
self.control_connection = ControlConnection(
|
||||
self, self.control_connection_timeout)
|
||||
|
||||
def register_type_class(self, keyspace, user_type, klass):
|
||||
def register_user_type(self, keyspace, user_type, klass):
|
||||
self._user_types[keyspace][user_type] = klass
|
||||
for session in self.sessions:
|
||||
self.session.user_type_registered(keyspace, user_type, klass)
|
||||
|
||||
def get_min_requests_per_connection(self, host_distance):
|
||||
return self._min_requests_per_connection[host_distance]
|
||||
@@ -602,6 +604,9 @@ class Cluster(object):
|
||||
|
||||
def _new_session(self):
|
||||
session = Session(self, self.metadata.all_hosts())
|
||||
for keyspace, type_map in six.iteritems(self._user_types):
|
||||
for udt_name, klass in six.iteritems(type_map):
|
||||
session.user_type_registered(keyspace, udt_name, klass)
|
||||
self.sessions.add(session)
|
||||
return session
|
||||
|
||||
@@ -1064,6 +1069,19 @@ class Session(object):
|
||||
_metrics = None
|
||||
_protocol_version = None
|
||||
|
||||
encoders = None
|
||||
|
||||
def user_type_registered(self, keyspace, user_type, klass):
|
||||
type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type]
|
||||
|
||||
def encode(val):
|
||||
return '{ %s }' % ' , '.join('%s : %s' % (
|
||||
field_name,
|
||||
cql_encode_all_types(getattr(val, field_name))
|
||||
) for field_name in type_meta.field_names)
|
||||
|
||||
self._encoders[klass] = encode
|
||||
|
||||
def __init__(self, cluster, hosts):
|
||||
self.cluster = cluster
|
||||
self.hosts = hosts
|
||||
@@ -1074,6 +1092,8 @@ class Session(object):
|
||||
self._metrics = cluster.metrics
|
||||
self._protocol_version = self.cluster.protocol_version
|
||||
|
||||
self._encoders = cql_encoders.copy()
|
||||
|
||||
# create connection pools in parallel
|
||||
futures = []
|
||||
for host in hosts:
|
||||
@@ -1196,7 +1216,7 @@ class Session(object):
|
||||
if isinstance(query, SimpleStatement):
|
||||
query_string = query.query_string
|
||||
if parameters:
|
||||
query_string = bind_params(query.query_string, parameters)
|
||||
query_string = bind_params(query.query_string, parameters, self._encoders)
|
||||
message = QueryMessage(
|
||||
query_string, cl, query.serial_consistency_level,
|
||||
fetch_size, timestamp=timestamp)
|
||||
@@ -1701,8 +1721,8 @@ class ControlConnection(object):
|
||||
cf_query, col_query)
|
||||
|
||||
log.debug("[control connection] Fetched table info for %s.%s, rebuilding metadata", (keyspace, table))
|
||||
cf_result = dict_factory(*cf_result.results)
|
||||
col_result = dict_factory(*col_result.results)
|
||||
cf_result = dict_factory(*cf_result.results) if cf_result else {}
|
||||
col_result = dict_factory(*col_result.results) if col_result else {}
|
||||
self._cluster.metadata.table_changed(keyspace, table, cf_result, col_result)
|
||||
elif usertype:
|
||||
# user defined types within this keyspace changed
|
||||
@@ -1710,7 +1730,7 @@ class ControlConnection(object):
|
||||
types_query = QueryMessage(query=self._SELECT_USERTYPES + where_clause, consistency_level=cl)
|
||||
types_result = connection.wait_for_response(types_query)
|
||||
log.debug("[control connection] Fetched user type info for %s.%s, rebuilding metadata", (keyspace, usertype))
|
||||
types_result = dict_factory(*types_result)
|
||||
types_result = dict_factory(*types_result.results) if types_result.results else {}
|
||||
self._cluster.metadata.usertype_changed(keyspace, usertype, types_result)
|
||||
elif keyspace:
|
||||
# only the keyspace itself changed (such as replication settings)
|
||||
@@ -1718,7 +1738,7 @@ class ControlConnection(object):
|
||||
ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl)
|
||||
ks_result = connection.wait_for_response(ks_query)
|
||||
log.debug("[control connection] Fetched keyspace info for %s, rebuilding metadata", (keyspace,))
|
||||
ks_result = dict_factory(*types_result)
|
||||
ks_result = dict_factory(*ks_result.results) if ks_result.results else {}
|
||||
self._cluster.metadata.keyspace_changed(keyspace, ks_result)
|
||||
else:
|
||||
# build everything from scratch
|
||||
@@ -1730,12 +1750,12 @@ class ControlConnection(object):
|
||||
if self._protocol_version >= 3:
|
||||
queries.append(QueryMessage(query=self._SELECT_USERTYPES, consistency_level=cl))
|
||||
ks_result, cf_result, col_result, types_result = connection.wait_for_responses(*queries)
|
||||
types_result = dict_factory(*types_result)
|
||||
types_result = dict_factory(*types_result.results) if types_result.results else {}
|
||||
else:
|
||||
ks_result, cf_result, col_result = connection.wait_for_responses(*queries)
|
||||
types_result = {}
|
||||
|
||||
ks_result = dict_factory(*types_result)
|
||||
ks_result = dict_factory(*ks_result.results)
|
||||
cf_result = dict_factory(*cf_result.results)
|
||||
col_result = dict_factory(*col_result.results)
|
||||
|
||||
|
@@ -560,14 +560,14 @@ class ResultMessage(_MessageType):
|
||||
ksname = read_string(f)
|
||||
results = ksname
|
||||
elif kind == RESULT_KIND_PREPARED:
|
||||
results = cls.recv_results_prepared(f)
|
||||
results = cls.recv_results_prepared(f, user_type_map)
|
||||
elif kind == RESULT_KIND_SCHEMA_CHANGE:
|
||||
results = cls.recv_results_schema_change(f, protocol_version)
|
||||
return cls(kind, results, paging_state)
|
||||
|
||||
@classmethod
|
||||
def recv_results_rows(cls, f, protocol_version, user_type_map):
|
||||
paging_state, column_metadata = cls.recv_results_metadata(f)
|
||||
paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map)
|
||||
rowcount = read_int(f)
|
||||
rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
|
||||
colnames = [c[2] for c in column_metadata]
|
||||
@@ -579,9 +579,9 @@ class ResultMessage(_MessageType):
|
||||
return (paging_state, (colnames, parsed_rows))
|
||||
|
||||
@classmethod
|
||||
def recv_results_prepared(cls, f):
|
||||
def recv_results_prepared(cls, f, user_type_map):
|
||||
query_id = read_binary_string(f)
|
||||
_, column_metadata = cls.recv_results_metadata(f)
|
||||
_, column_metadata = cls.recv_results_metadata(f, user_type_map)
|
||||
return (query_id, column_metadata)
|
||||
|
||||
@classmethod
|
||||
|
@@ -567,9 +567,10 @@ class BatchStatement(Statement):
|
||||
"""
|
||||
|
||||
_statements_and_parameters = None
|
||||
_session = None
|
||||
|
||||
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
|
||||
consistency_level=None):
|
||||
consistency_level=None, session=None):
|
||||
"""
|
||||
`batch_type` specifies The :class:`.BatchType` for the batch operation.
|
||||
Defaults to :attr:`.BatchType.LOGGED`.
|
||||
@@ -605,6 +606,7 @@ class BatchStatement(Statement):
|
||||
"""
|
||||
self.batch_type = batch_type
|
||||
self._statements_and_parameters = []
|
||||
self._session = session
|
||||
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level)
|
||||
|
||||
def add(self, statement, parameters=None):
|
||||
@@ -617,7 +619,8 @@ class BatchStatement(Statement):
|
||||
"""
|
||||
if isinstance(statement, six.string_types):
|
||||
if parameters:
|
||||
statement = bind_params(statement, parameters)
|
||||
encoders = cql_encoders if self._session is None else self._session.encoders
|
||||
statement = bind_params(statement, parameters, encoders)
|
||||
self._statements_and_parameters.append((False, statement, ()))
|
||||
elif isinstance(statement, PreparedStatement):
|
||||
query_id = statement.query_id
|
||||
@@ -635,7 +638,8 @@ class BatchStatement(Statement):
|
||||
# it must be a SimpleStatement
|
||||
query_string = statement.query_string
|
||||
if parameters:
|
||||
query_string = bind_params(query_string, parameters)
|
||||
encoders = cql_encoders if self._session is None else self._session.encoders
|
||||
query_string = bind_params(query_string, parameters, encoders)
|
||||
self._statements_and_parameters.append((False, query_string, ()))
|
||||
return self
|
||||
|
||||
@@ -677,11 +681,11 @@ class ValueSequence(object):
|
||||
return cql_encode_sequence(self.sequence)
|
||||
|
||||
|
||||
def bind_params(query, params):
|
||||
def bind_params(query, params, encoders):
|
||||
if isinstance(params, dict):
|
||||
return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
|
||||
return query % dict((k, encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
|
||||
else:
|
||||
return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v) for v in params)
|
||||
return query % tuple(encoders.get(type(v), cql_encode_object)(v) for v in params)
|
||||
|
||||
|
||||
class TraceUnavailable(Exception):
|
||||
|
@@ -17,6 +17,7 @@ try:
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.encoder import cql_encoders
|
||||
from cassandra.query import bind_params, ValueSequence
|
||||
from cassandra.query import PreparedStatement, BoundStatement
|
||||
from cassandra.cqltypes import Int32Type
|
||||
@@ -28,31 +29,31 @@ from six.moves import xrange
|
||||
class ParamBindingTest(unittest.TestCase):
|
||||
|
||||
def test_bind_sequence(self):
|
||||
result = bind_params("%s %s %s", (1, "a", 2.0))
|
||||
result = bind_params("%s %s %s", (1, "a", 2.0), cql_encoders)
|
||||
self.assertEqual(result, "1 'a' 2.0")
|
||||
|
||||
def test_bind_map(self):
|
||||
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0))
|
||||
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), cql_encoders)
|
||||
self.assertEqual(result, "1 'a' 2.0")
|
||||
|
||||
def test_sequence_param(self):
|
||||
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),))
|
||||
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), cql_encoders)
|
||||
self.assertEqual(result, "( 1 , 'a' , 2.0 )")
|
||||
|
||||
def test_generator_param(self):
|
||||
result = bind_params("%s", ((i for i in xrange(3)),))
|
||||
result = bind_params("%s", ((i for i in xrange(3)),), cql_encoders)
|
||||
self.assertEqual(result, "[ 0 , 1 , 2 ]")
|
||||
|
||||
def test_none_param(self):
|
||||
result = bind_params("%s", (None,))
|
||||
result = bind_params("%s", (None,), cql_encoders)
|
||||
self.assertEqual(result, "NULL")
|
||||
|
||||
def test_list_collection(self):
|
||||
result = bind_params("%s", (['a', 'b', 'c'],))
|
||||
result = bind_params("%s", (['a', 'b', 'c'],), cql_encoders)
|
||||
self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
|
||||
|
||||
def test_set_collection(self):
|
||||
result = bind_params("%s", (set(['a', 'b']),))
|
||||
result = bind_params("%s", (set(['a', 'b']),), cql_encoders)
|
||||
self.assertIn(result, ("{ 'a' , 'b' }", "{ 'b' , 'a' }"))
|
||||
|
||||
def test_map_collection(self):
|
||||
@@ -60,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
|
||||
vals['a'] = 'a'
|
||||
vals['b'] = 'b'
|
||||
vals['c'] = 'c'
|
||||
result = bind_params("%s", (vals,))
|
||||
result = bind_params("%s", (vals,), cql_encoders)
|
||||
self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
|
||||
|
||||
def test_quote_escaping(self):
|
||||
result = bind_params("%s", ("""'ef''ef"ef""ef'""",))
|
||||
result = bind_params("%s", ("""'ef''ef"ef""ef'""",), cql_encoders)
|
||||
self.assertEqual(result, """'''ef''''ef"ef""ef'''""")
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user