Switch from class-level params to init args
This commit is contained in:
@@ -31,7 +31,9 @@ from cassandra.decoder import (QueryMessage, ResultMessage,
|
||||
PrepareMessage, ExecuteMessage,
|
||||
PreparedQueryNotFound,
|
||||
IsBootstrappingErrorMessage,
|
||||
BatchMessage)
|
||||
BatchMessage, RESULT_KIND_PREPARED,
|
||||
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
|
||||
RESULT_KIND_SCHEMA_CHANGE)
|
||||
from cassandra.metadata import Metadata
|
||||
from cassandra.metrics import Metrics
|
||||
from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy,
|
||||
@@ -855,7 +857,7 @@ class Cluster(object):
|
||||
responses = connection.wait_for_responses(*messages, timeout=5.0)
|
||||
for response in responses:
|
||||
if (not isinstance(response, ResultMessage) or
|
||||
response.kind != ResultMessage.KIND_PREPARED):
|
||||
response.kind != RESULT_KIND_PREPARED):
|
||||
log.debug("Got unexpected response when preparing "
|
||||
"statement on host %s: %r", host, response)
|
||||
|
||||
@@ -1068,18 +1070,15 @@ class Session(object):
|
||||
if parameters:
|
||||
parameters = encode_params(parameters)
|
||||
message = QueryMessage(
|
||||
query=query.query_string, consistency_level=query.consistency_level,
|
||||
parameters=parameters)
|
||||
query.query_string, query.consistency_level, parameters)
|
||||
else:
|
||||
query_string = query.query_string
|
||||
if parameters:
|
||||
query_string = bind_params(query.query_string, parameters)
|
||||
message = QueryMessage(query=query_string, consistency_level=query.consistency_level)
|
||||
message = QueryMessage(query_string, query.consistency_level)
|
||||
elif isinstance(query, BoundStatement):
|
||||
message = ExecuteMessage(
|
||||
query_id=query.prepared_statement.query_id,
|
||||
query_params=query.values,
|
||||
consistency_level=query.consistency_level)
|
||||
query.prepared_statement.query_id, query.values, query.consistency_level)
|
||||
prepared_statement = query.prepared_statement
|
||||
elif isinstance(query, BatchStatement):
|
||||
if self._protocol_version < 2:
|
||||
@@ -1088,9 +1087,8 @@ class Session(object):
|
||||
"2 or higher (supported in Cassandra 2.0 and higher). Consider "
|
||||
"setting Cluster.protocol_version to 2 to support this operation.")
|
||||
message = BatchMessage(
|
||||
batch_type=query.batch_type,
|
||||
queries=query._statements_and_values,
|
||||
consistency_level=query.consistency_level)
|
||||
query.batch_type, query._statements_and_parameters,
|
||||
query.consistency_level)
|
||||
|
||||
if trace:
|
||||
message.tracing = True
|
||||
@@ -1949,7 +1947,7 @@ class ResponseFuture(object):
|
||||
self._query_trace = QueryTrace(trace_id, self.session)
|
||||
|
||||
if isinstance(response, ResultMessage):
|
||||
if response.kind == ResultMessage.KIND_SET_KEYSPACE:
|
||||
if response.kind == RESULT_KIND_SET_KEYSPACE:
|
||||
session = getattr(self, 'session', None)
|
||||
# since we're running on the event loop thread, we need to
|
||||
# use a non-blocking method for setting the keyspace on
|
||||
@@ -1961,7 +1959,7 @@ class ResponseFuture(object):
|
||||
if session:
|
||||
session._set_keyspace_for_all_pools(
|
||||
response.results, self._set_keyspace_completed)
|
||||
elif response.kind == ResultMessage.KIND_SCHEMA_CHANGE:
|
||||
elif response.kind == RESULT_KIND_SCHEMA_CHANGE:
|
||||
# refresh the schema before responding, but do it in another
|
||||
# thread instead of the event loop thread
|
||||
self.session.submit(
|
||||
@@ -1972,7 +1970,7 @@ class ResponseFuture(object):
|
||||
self)
|
||||
else:
|
||||
results = getattr(response, 'results', None)
|
||||
if results is not None and response.kind == ResultMessage.KIND_ROWS:
|
||||
if results is not None and response.kind == RESULT_KIND_ROWS:
|
||||
results = self.row_factory(*results)
|
||||
self._set_final_result(results)
|
||||
elif isinstance(response, ErrorMessage):
|
||||
@@ -2107,7 +2105,7 @@ class ResponseFuture(object):
|
||||
return
|
||||
|
||||
if isinstance(response, ResultMessage):
|
||||
if response.kind == ResultMessage.KIND_PREPARED:
|
||||
if response.kind == RESULT_KIND_PREPARED:
|
||||
# use self._query to re-use the same host and
|
||||
# at the same time properly borrow the connection
|
||||
request_id = self._query(self._current_host)
|
||||
|
||||
@@ -48,19 +48,9 @@ class _register_msg_type(type):
|
||||
|
||||
class _MessageType(object):
|
||||
__metaclass__ = _register_msg_type
|
||||
params = ()
|
||||
|
||||
tracing = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for pname in self.params:
|
||||
try:
|
||||
pval = kwargs[pname]
|
||||
except KeyError:
|
||||
raise ValueError("%s instances need the %s keyword parameter"
|
||||
% (self.__class__.__name__, pname))
|
||||
setattr(self, pname, pval)
|
||||
|
||||
def to_string(self, stream_id, protocol_version, compression=None):
|
||||
body = StringIO()
|
||||
self.send_body(body)
|
||||
@@ -77,11 +67,17 @@ class _MessageType(object):
|
||||
return ''.join(msg_parts)
|
||||
|
||||
def __str__(self):
|
||||
paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in self.params]
|
||||
paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in _get_params(self)]
|
||||
return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs))
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def _get_params(message_obj):
|
||||
base_attrs = dir(_MessageType)
|
||||
return [a for a in dir(message_obj)
|
||||
if a not in base_attrs and not a.startswith('_') and not callable(getattr(message_obj, a))]
|
||||
|
||||
|
||||
def decode_response(stream_id, flags, opcode, body, decompressor=None):
|
||||
if flags & 0x01:
|
||||
if decompressor is None:
|
||||
@@ -112,9 +108,13 @@ error_classes = {}
|
||||
class ErrorMessage(_MessageType, Exception):
|
||||
opcode = 0x00
|
||||
name = 'ERROR'
|
||||
params = ('code', 'message', 'info')
|
||||
summary = 'Unknown'
|
||||
|
||||
def __init__(self, code, message, info):
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.info = info
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f):
|
||||
code = read_int(f)
|
||||
@@ -290,13 +290,16 @@ class AlreadyExistsException(ConfigurationException):
|
||||
class StartupMessage(_MessageType):
|
||||
opcode = 0x01
|
||||
name = 'STARTUP'
|
||||
params = ('cqlversion', 'options')
|
||||
|
||||
KNOWN_OPTION_KEYS = set((
|
||||
'CQL_VERSION',
|
||||
'COMPRESSION',
|
||||
))
|
||||
|
||||
def __init__(self, cqlversion, options):
|
||||
self.cqlversion = cqlversion
|
||||
self.options = options
|
||||
|
||||
def send_body(self, f):
|
||||
optmap = self.options.copy()
|
||||
optmap['CQL_VERSION'] = self.cqlversion
|
||||
@@ -306,7 +309,6 @@ class StartupMessage(_MessageType):
|
||||
class ReadyMessage(_MessageType):
|
||||
opcode = 0x02
|
||||
name = 'READY'
|
||||
params = ()
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f):
|
||||
@@ -316,7 +318,9 @@ class ReadyMessage(_MessageType):
|
||||
class AuthenticateMessage(_MessageType):
|
||||
opcode = 0x03
|
||||
name = 'AUTHENTICATE'
|
||||
params = ('authenticator',)
|
||||
|
||||
def __init__(self, authenticator):
|
||||
self.authenticator = authenticator
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f):
|
||||
@@ -327,7 +331,9 @@ class AuthenticateMessage(_MessageType):
|
||||
class CredentialsMessage(_MessageType):
|
||||
opcode = 0x04
|
||||
name = 'CREDENTIALS'
|
||||
params = ('creds',)
|
||||
|
||||
def __init__(self, creds):
|
||||
self.creds = creds
|
||||
|
||||
def send_body(self, f):
|
||||
write_short(f, len(self.creds))
|
||||
@@ -339,7 +345,6 @@ class CredentialsMessage(_MessageType):
|
||||
class OptionsMessage(_MessageType):
|
||||
opcode = 0x05
|
||||
name = 'OPTIONS'
|
||||
params = ()
|
||||
|
||||
def send_body(self, f):
|
||||
pass
|
||||
@@ -348,7 +353,10 @@ class OptionsMessage(_MessageType):
|
||||
class SupportedMessage(_MessageType):
|
||||
opcode = 0x06
|
||||
name = 'SUPPORTED'
|
||||
params = ('cql_versions', 'options',)
|
||||
|
||||
def __init__(self, cql_versions, options):
|
||||
self.cql_versions = cql_versions
|
||||
self.options = options
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f):
|
||||
@@ -360,7 +368,17 @@ class SupportedMessage(_MessageType):
|
||||
class QueryMessage(_MessageType):
|
||||
opcode = 0x07
|
||||
name = 'QUERY'
|
||||
params = ('query', 'consistency_level',)
|
||||
|
||||
_VALUES_FLAG = 0x01
|
||||
_SKIP_METADATA_FLAG = 0x01
|
||||
_PAGE_SIZE_FLAG = 0x04
|
||||
_WITH_PAGING_STATE_SIZE_FLAG = 0x08
|
||||
_WITH_SERIAL_CONSISTENCY_FLAG = 0x10
|
||||
|
||||
def __init__(self, query, consistency_level, values=None):
|
||||
self.query = query
|
||||
self.consistency_level = consistency_level
|
||||
self.values = values
|
||||
|
||||
def send_body(self, f):
|
||||
write_longstring(f, self.query)
|
||||
@@ -369,19 +387,18 @@ class QueryMessage(_MessageType):
|
||||
|
||||
CUSTOM_TYPE = object()
|
||||
|
||||
RESULT_KIND_VOID = 0x0001
|
||||
RESULT_KIND_ROWS = 0x0002
|
||||
RESULT_KIND_SET_KEYSPACE = 0x0003
|
||||
RESULT_KIND_PREPARED = 0x0004
|
||||
RESULT_KIND_SCHEMA_CHANGE = 0x0005
|
||||
|
||||
|
||||
class ResultMessage(_MessageType):
|
||||
opcode = 0x08
|
||||
name = 'RESULT'
|
||||
params = ('kind', 'results',)
|
||||
|
||||
KIND_VOID = 0x0001
|
||||
KIND_ROWS = 0x0002
|
||||
KIND_SET_KEYSPACE = 0x0003
|
||||
KIND_PREPARED = 0x0004
|
||||
KIND_SCHEMA_CHANGE = 0x0005
|
||||
|
||||
type_codes = {
|
||||
_type_codes = {
|
||||
0x0000: CUSTOM_TYPE,
|
||||
0x0001: AsciiType,
|
||||
0x0002: LongType,
|
||||
@@ -404,21 +421,25 @@ class ResultMessage(_MessageType):
|
||||
0x0022: SetType,
|
||||
}
|
||||
|
||||
FLAGS_GLOBAL_TABLES_SPEC = 0x0001
|
||||
_FLAGS_GLOBAL_TABLES_SPEC = 0x0001
|
||||
|
||||
def __init__(self, kind, results):
|
||||
self.kind = kind
|
||||
self.results = results
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f):
|
||||
kind = read_int(f)
|
||||
if kind == cls.KIND_VOID:
|
||||
if kind == RESULT_KIND_VOID:
|
||||
results = None
|
||||
elif kind == cls.KIND_ROWS:
|
||||
elif kind == RESULT_KIND_ROWS:
|
||||
results = cls.recv_results_rows(f)
|
||||
elif kind == cls.KIND_SET_KEYSPACE:
|
||||
elif kind == RESULT_KIND_SET_KEYSPACE:
|
||||
ksname = read_string(f)
|
||||
results = ksname
|
||||
elif kind == cls.KIND_PREPARED:
|
||||
elif kind == RESULT_KIND_PREPARED:
|
||||
results = cls.recv_results_prepared(f)
|
||||
elif kind == cls.KIND_SCHEMA_CHANGE:
|
||||
elif kind == RESULT_KIND_SCHEMA_CHANGE:
|
||||
results = cls.recv_results_schema_change(f)
|
||||
return cls(kind=kind, results=results)
|
||||
|
||||
@@ -441,7 +462,7 @@ class ResultMessage(_MessageType):
|
||||
@classmethod
|
||||
def recv_results_metadata(cls, f):
|
||||
flags = read_int(f)
|
||||
glob_tblspec = bool(flags & cls.FLAGS_GLOBAL_TABLES_SPEC)
|
||||
glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
|
||||
colcount = read_int(f)
|
||||
if glob_tblspec:
|
||||
ksname = read_string(f)
|
||||
@@ -470,7 +491,7 @@ class ResultMessage(_MessageType):
|
||||
def read_type(cls, f):
|
||||
optid = read_short(f)
|
||||
try:
|
||||
typeclass = cls.type_codes[optid]
|
||||
typeclass = cls._type_codes[optid]
|
||||
except KeyError:
|
||||
raise NotSupportedError("Unknown data type code 0x%04x. Have to skip"
|
||||
" entire result set." % (optid,))
|
||||
@@ -495,7 +516,9 @@ class ResultMessage(_MessageType):
|
||||
class PrepareMessage(_MessageType):
|
||||
opcode = 0x09
|
||||
name = 'PREPARE'
|
||||
params = ('query',)
|
||||
|
||||
def __init__(self, query):
|
||||
self.query = query
|
||||
|
||||
def send_body(self, f):
|
||||
write_longstring(f, self.query)
|
||||
@@ -504,7 +527,11 @@ class PrepareMessage(_MessageType):
|
||||
class ExecuteMessage(_MessageType):
|
||||
opcode = 0x0A
|
||||
name = 'EXECUTE'
|
||||
params = ('query_id', 'query_params', 'consistency_level',)
|
||||
|
||||
def __init__(self, query_id, query_params, consistency_level):
|
||||
self.query_id = query_id
|
||||
self.query_params = query_params
|
||||
self.consistency_level = consistency_level
|
||||
|
||||
def send_body(self, f):
|
||||
write_string(f, self.query_id)
|
||||
@@ -517,7 +544,11 @@ class ExecuteMessage(_MessageType):
|
||||
class BatchMessage(_MessageType):
|
||||
opcode = 0x0D
|
||||
name = 'BATCH'
|
||||
params = ('batch_type', 'queries', 'consistency_level',)
|
||||
|
||||
def __init__(self, batch_type, queries, consistency_level):
|
||||
self.batch_type = batch_type
|
||||
self.queries = queries
|
||||
self.consistency_level = consistency_level
|
||||
|
||||
def send_body(self, f):
|
||||
write_byte(f, self.batch_type.value)
|
||||
@@ -546,16 +577,21 @@ known_event_types = frozenset((
|
||||
class RegisterMessage(_MessageType):
|
||||
opcode = 0x0B
|
||||
name = 'REGISTER'
|
||||
params = ('event_list',)
|
||||
|
||||
def send_body(self, f):
|
||||
def __init__(self, event_list):
|
||||
self.event_list = event_list
|
||||
|
||||
def send_body(self, f, protocol_version):
|
||||
write_stringlist(f, self.event_list)
|
||||
|
||||
|
||||
class EventMessage(_MessageType):
|
||||
opcode = 0x0C
|
||||
name = 'EVENT'
|
||||
params = ('event_type', 'event_args')
|
||||
|
||||
def __init__(self, event_type, event_args):
|
||||
self.event_type = event_type
|
||||
self.event_args = event_args
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f):
|
||||
|
||||
@@ -8,7 +8,7 @@ from mock import Mock, ANY
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from cassandra import OperationTimedOut
|
||||
from cassandra.decoder import ResultMessage
|
||||
from cassandra.decoder import ResultMessage, RESULT_KIND_ROWS
|
||||
from cassandra.cluster import ControlConnection, Cluster, _Scheduler
|
||||
from cassandra.pool import Host
|
||||
from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy,
|
||||
@@ -16,6 +16,7 @@ from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy,
|
||||
|
||||
PEER_IP = "foobar"
|
||||
|
||||
|
||||
class MockMetadata(object):
|
||||
|
||||
def __init__(self):
|
||||
@@ -91,9 +92,9 @@ class MockConnection(object):
|
||||
|
||||
def wait_for_responses(self, peer_query, local_query, timeout=None):
|
||||
local_response = ResultMessage(
|
||||
kind=ResultMessage.KIND_ROWS, results=self.local_results)
|
||||
kind=RESULT_KIND_ROWS, results=self.local_results)
|
||||
peer_response = ResultMessage(
|
||||
kind=ResultMessage.KIND_ROWS, results=self.peer_results)
|
||||
kind=RESULT_KIND_ROWS, results=self.peer_results)
|
||||
return (peer_response, local_response)
|
||||
|
||||
|
||||
|
||||
@@ -11,11 +11,14 @@ from cassandra.connection import ConnectionException
|
||||
from cassandra.decoder import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage,
|
||||
UnavailableErrorMessage, ResultMessage, QueryMessage,
|
||||
OverloadedErrorMessage, IsBootstrappingErrorMessage,
|
||||
PreparedQueryNotFound, PrepareMessage)
|
||||
PreparedQueryNotFound, PrepareMessage,
|
||||
RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE,
|
||||
RESULT_KIND_SCHEMA_CHANGE)
|
||||
from cassandra.policies import RetryPolicy
|
||||
from cassandra.pool import NoConnectionsAvailable
|
||||
from cassandra.query import SimpleStatement
|
||||
|
||||
|
||||
class ResponseFutureTests(unittest.TestCase):
|
||||
|
||||
def make_basic_session(self):
|
||||
@@ -46,7 +49,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
connection = pool.borrow_connection.return_value
|
||||
connection.send_msg.assert_called_once_with(rf.message, cb=ANY)
|
||||
|
||||
response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}])
|
||||
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
|
||||
rf._set_result(response)
|
||||
|
||||
result = rf.result()
|
||||
@@ -65,7 +68,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
rf.send_request()
|
||||
|
||||
result = Mock(spec=ResultMessage,
|
||||
kind=ResultMessage.KIND_SET_KEYSPACE,
|
||||
kind=RESULT_KIND_SET_KEYSPACE,
|
||||
results="keyspace1")
|
||||
rf._set_result(result)
|
||||
rf._set_keyspace_completed({})
|
||||
@@ -77,7 +80,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
rf.send_request()
|
||||
|
||||
result = Mock(spec=ResultMessage,
|
||||
kind=ResultMessage.KIND_SCHEMA_CHANGE,
|
||||
kind=RESULT_KIND_SCHEMA_CHANGE,
|
||||
results={'keyspace': "keyspace1", "table": "table1"})
|
||||
rf._set_result(result)
|
||||
session.submit.assert_called_once_with(ANY, 'keyspace1', 'table1', ANY, rf)
|
||||
@@ -256,7 +259,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
rf = self.make_response_future(session)
|
||||
rf.send_request()
|
||||
|
||||
response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}])
|
||||
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
|
||||
rf._set_result(response)
|
||||
|
||||
result = rf.result()
|
||||
@@ -277,7 +280,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
rf = self.make_response_future(session)
|
||||
rf.send_request()
|
||||
|
||||
response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}])
|
||||
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
|
||||
rf._set_result(response)
|
||||
self.assertEqual(rf.result(), [{'col': 'val'}])
|
||||
|
||||
@@ -291,7 +294,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
|
||||
rf.add_callback(self.assertEqual, [{'col': 'val'}])
|
||||
|
||||
response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}])
|
||||
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
|
||||
rf._set_result(response)
|
||||
|
||||
result = rf.result()
|
||||
@@ -346,7 +349,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
callback=self.assertEquals, callback_args=([{'col': 'val'}],),
|
||||
errback=self.assertIsInstance, errback_args=(Exception,))
|
||||
|
||||
response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}])
|
||||
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
|
||||
rf._set_result(response)
|
||||
self.assertEqual(rf.result(), [{'col': 'val'}])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user