Merge branch 'master' of github.com:datastax/python-driver into PYTHON-96-test
@@ -5,6 +5,15 @@ In Progress
 Bug Fixes
 ---------
+* Properly specify UDTs for columns in CREATE TABLE statements
+* Avoid moving retries to a new host when using request ID zero (PYTHON-88)
+* Don't ignore fetch_size arguments to Statement constructors (github-151)
+* Allow disabling automatic paging on a per-statement basis when it's
+  enabled by default for the session (PYTHON-93)
+* Raise ValueError when tuple query parameters for prepared statements
+  have extra items (PYTHON-98)
+* Correctly encode nested tuples and UDTs for non-prepared statements (PYTHON-100)
+* Raise TypeError when a string is used for contact_points (github #164)
+* Include User Defined Types in KeyspaceMetadata.export_as_string() (PYTHON-96)

 Other
 -----
@@ -4,6 +4,10 @@ Releasing
+* If dependencies have changed, make sure ``debian/control``
+  is up to date
+* Make sure all patches in ``debian/patches`` still apply cleanly
+* Update the debian changelog with the new version::
+
+      dch -v '1.0.0'
+
 * Update CHANGELOG.rst
 * Update the version in ``cassandra/__init__.py``
 * Commit the changelog and version changes
@@ -42,9 +42,9 @@ from functools import partial, wraps
 from itertools import groupby

 from cassandra import (ConsistencyLevel, AuthenticationFailed,
-                       OperationTimedOut, UnsupportedOperation)
+                       InvalidRequest, OperationTimedOut, UnsupportedOperation)
 from cassandra.connection import ConnectionException, ConnectionShutdown
-from cassandra.encoder import cql_encode_all_types, cql_encoders
+from cassandra.encoder import Encoder
 from cassandra.protocol import (QueryMessage, ResultMessage,
                                 ErrorMessage, ReadTimeoutErrorMessage,
                                 WriteTimeoutErrorMessage,
@@ -65,7 +65,7 @@ from cassandra.pool import (_ReconnectionHandler, _HostReconnectionHandler,
                             NoConnectionsAvailable)
 from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
                              BatchStatement, bind_params, QueryTrace, Statement,
-                             named_tuple_factory, dict_factory)
+                             named_tuple_factory, dict_factory, FETCH_SIZE_UNSET)

 # default to gevent when we are monkey patched, otherwise if libev is available, use that as the
 # default because it's fastest. Otherwise, use asyncore.
@@ -380,7 +380,12 @@ class Cluster(object):
         Any of the mutable Cluster attributes may be set as keyword arguments
         to the constructor.
         """
-        self.contact_points = contact_points
+        if contact_points is not None:
+            if isinstance(contact_points, six.string_types):
+                raise TypeError("contact_points should not be a string, it should be a sequence (e.g. list) of strings")
+
+            self.contact_points = contact_points
+
         self.port = port
         self.compression = compression
         self.protocol_version = protocol_version
@@ -467,6 +472,53 @@ class Cluster(object):
             self, self.control_connection_timeout)

+    def register_user_type(self, keyspace, user_type, klass):
+        """
+        Registers a class to use to represent a particular user-defined type.
+        Query parameters for this user-defined type will be assumed to be
+        instances of `klass`. Result sets for this user-defined type will
+        be instances of `klass`. If no class is registered for a user-defined
+        type, a namedtuple will be used for result sets, and non-prepared
+        statements may not encode parameters for this type correctly.
+
+        `keyspace` is the name of the keyspace that the UDT is defined in.
+
+        `user_type` is the string name of the UDT to register the mapping
+        for.
+
+        `klass` should be a class with attributes whose names match the
+        fields of the user-defined type. The constructor must accept kwargs
+        for each of the fields in the UDT.
+
+        This method should only be called after the type has been created
+        within Cassandra.
+
+        Example::
+
+            cluster = Cluster(protocol_version=3)
+            session = cluster.connect()
+            session.set_keyspace('mykeyspace')
+            session.execute("CREATE TYPE address (street text, zipcode int)")
+            session.execute("CREATE TABLE users (id int PRIMARY KEY, location address)")
+
+            # create a class to map to the "address" UDT
+            class Address(object):
+
+                def __init__(self, street, zipcode):
+                    self.street = street
+                    self.zipcode = zipcode
+
+            cluster.register_user_type('mykeyspace', 'address', Address)
+
+            # insert a row using an instance of Address
+            session.execute("INSERT INTO users (id, location) VALUES (%s, %s)",
+                            (0, Address("123 Main St.", 78723)))
+
+            # results will include Address instances
+            results = session.execute("SELECT * FROM users")
+            row = results[0]
+            print row.id, row.location.street, row.location.zipcode
+        """
+        self._user_types[keyspace][user_type] = klass
+        for session in self.sessions:
+            session.user_type_registered(keyspace, user_type, klass)
+
@@ -591,6 +643,8 @@ class Cluster(object):
             raise Exception("Cluster is already shut down")

         if not self._is_setup:
+            log.debug("Connecting to cluster, contact points: %s; protocol version: %s",
+                      self.contact_points, self.protocol_version)
             self.connection_class.initialize_reactor()
             atexit.register(partial(_shutdown_cluster, self))
             for address in self.contact_points:
@@ -1115,14 +1169,36 @@ class Session(object):
     .. versionadded:: 2.1.0
     """

+    encoder = None
+    """
+    A :class:`~cassandra.encoder.Encoder` instance that will be used when
+    formatting query parameters for non-prepared statements. This is not used
+    for prepared statements (because prepared statements give the driver more
+    information about what CQL types are expected, allowing it to accept a
+    wider range of python types).
+
+    The encoder uses a mapping from python types to encoder methods (for
+    specific CQL types). This mapping can be modified by users as they see
+    fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping
+    values if possible, because they take precautions to avoid injections and
+    properly sanitize data.
+
+    Example::
+
+        cluster = Cluster()
+        session = cluster.connect("mykeyspace")
+        session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple
+
+        session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple<int, ascii>)")
+        session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')])
+    """
+
     _lock = None
     _pools = None
     _load_balancer = None
     _metrics = None
     _protocol_version = None

-    encoders = None
-
     def __init__(self, cluster, hosts):
         self.cluster = cluster
         self.hosts = hosts
@@ -1133,7 +1209,7 @@ class Session(object):
         self._metrics = cluster.metrics
         self._protocol_version = self.cluster.protocol_version

-        self.encoders = cql_encoders.copy()
+        self.encoder = Encoder()

         # create connection pools in parallel
         futures = []
@@ -1246,8 +1322,10 @@ class Session(object):

         cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level
         fetch_size = query.fetch_size
-        if not fetch_size and self._protocol_version >= 2:
+        if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2:
             fetch_size = self.default_fetch_size
+        elif self._protocol_version == 1:
+            fetch_size = None

         if self._protocol_version >= 3 and self.use_client_timestamp:
             timestamp = int(time.time() * 1e6)
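The ``FETCH_SIZE_UNSET`` sentinel is what lets ``None`` and ``0`` remain meaningful, explicit fetch sizes (both disable paging) while "never set" still falls back to the session default. A minimal standalone sketch of the same pattern, using simplified names::

    FETCH_SIZE_UNSET = object()  # unique sentinel; identical to nothing else

    def resolve_fetch_size(statement_fetch_size, session_default):
        # None (disable paging) and 0 are legitimate explicit choices,
        # so only the sentinel falls through to the session default.
        if statement_fetch_size is FETCH_SIZE_UNSET:
            return session_default
        return statement_fetch_size

    assert resolve_fetch_size(FETCH_SIZE_UNSET, 5000) == 5000
    assert resolve_fetch_size(None, 5000) is None
    assert resolve_fetch_size(10, 5000) == 10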
@@ -1259,7 +1337,7 @@ class Session(object):
         if six.PY2 and isinstance(query_string, six.text_type):
             query_string = query_string.encode('utf-8')
         if parameters:
-            query_string = bind_params(query_string, parameters, self.encoders)
+            query_string = bind_params(query_string, parameters, self.encoder)
         message = QueryMessage(
             query_string, cl, query.serial_consistency_level,
             fetch_size, timestamp=timestamp)
@@ -1504,15 +1582,25 @@ class Session(object):
         mapping from a user-defined type to a class. Intended for internal
         use only.
         """
-        type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type]
+        try:
+            ks_meta = self.cluster.metadata.keyspaces[keyspace]
+        except KeyError:
+            raise UserTypeDoesNotExist(
+                'Keyspace %s does not exist or has not been discovered by the driver' % (keyspace,))
+
+        try:
+            type_meta = ks_meta.user_types[user_type]
+        except KeyError:
+            raise UserTypeDoesNotExist(
+                'User type %s does not exist in keyspace %s' % (user_type, keyspace))

         def encode(val):
             return '{ %s }' % ' , '.join('%s : %s' % (
                 field_name,
-                cql_encode_all_types(getattr(val, field_name))
+                self.encoder.cql_encode_all_types(getattr(val, field_name))
             ) for field_name in type_meta.field_names)

-        self.encoders[klass] = encode
+        self.encoder.mapping[klass] = encode

     def submit(self, fn, *args, **kwargs):
         """ Internal """
@@ -1523,6 +1611,13 @@ class Session(object):
         return dict((host, pool.get_state()) for host, pool in self._pools.items())


+class UserTypeDoesNotExist(Exception):
+    """
+    An attempt was made to use a user-defined type that does not exist.
+    """
+    pass
+
+
 class _ControlReconnectionHandler(_ReconnectionHandler):
     """
     Internal
@@ -1807,19 +1902,37 @@ class ControlConnection(object):
         queries = [
             QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl),
             QueryMessage(query=self._SELECT_COLUMN_FAMILIES, consistency_level=cl),
-            QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl)
+            QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl),
+            QueryMessage(query=self._SELECT_USERTYPES, consistency_level=cl)
         ]
-        if self._protocol_version >= 3:
-            queries.append(QueryMessage(query=self._SELECT_USERTYPES, consistency_level=cl))
-            ks_result, cf_result, col_result, types_result = connection.wait_for_responses(*queries)
-            types_result = dict_factory(*types_result.results) if types_result.results else {}
-        else:
-            ks_result, cf_result, col_result = connection.wait_for_responses(*queries)
-            types_result = {}
-
-        ks_result = dict_factory(*ks_result.results)
-        cf_result = dict_factory(*cf_result.results)
-        col_result = dict_factory(*col_result.results)
+
+        responses = connection.wait_for_responses(*queries, fail_on_error=False)
+        (ks_success, ks_result), (cf_success, cf_result), (col_success, col_result), (types_success, types_result) = responses
+
+        if ks_success:
+            ks_result = dict_factory(*ks_result.results)
+        else:
+            raise ks_result
+
+        if cf_success:
+            cf_result = dict_factory(*cf_result.results)
+        else:
+            raise cf_result
+
+        if col_success:
+            col_result = dict_factory(*col_result.results)
+        else:
+            raise col_result
+
+        # if we're connected to Cassandra < 2.1, the usertypes table will not exist
+        if types_success:
+            types_result = dict_factory(*types_result.results) if types_result.results else {}
+        else:
+            if isinstance(types_result, InvalidRequest):
+                log.debug("[control connection] user types table not found")
+                types_result = {}
+            else:
+                raise types_result

         log.debug("[control connection] Fetched schema, rebuilding metadata")
         self._cluster.metadata.rebuild_schema(ks_result, types_result, cf_result, col_result)
@@ -2558,7 +2671,7 @@ class ResponseFuture(object):
             # to retry the operation
             return

-        if reuse_connection and self._query(self._current_host):
+        if reuse_connection and self._query(self._current_host) is not None:
             return

         # otherwise, move onto another host
@@ -302,10 +302,19 @@ class Connection(object):
         return self.wait_for_responses(msg, timeout=timeout)[0]

     def wait_for_responses(self, *msgs, **kwargs):
+        """
+        If `fail_on_error` was set to :const:`False`, returns a list of
+        (success, response) tuples. If success is False, response will be
+        an Exception. Otherwise, response will be the normal query response.
+
+        If `fail_on_error` was left as :const:`True` and one of the requests
+        failed, the corresponding Exception will be raised.
+        """
         if self.is_closed or self.is_defunct:
             raise ConnectionShutdown("Connection %s is already closed" % (self, ))
         timeout = kwargs.get('timeout')
-        waiter = ResponseWaiter(self, len(msgs))
+        fail_on_error = kwargs.get('fail_on_error', True)
+        waiter = ResponseWaiter(self, len(msgs), fail_on_error)

         # busy wait for sufficient space on the connection
         messages_sent = 0
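With ``fail_on_error=False``, callers receive per-request outcomes instead of a single raised exception, which is how the schema fetch above survives a missing usertypes table. A hypothetical caller (``conn``, the message objects, and ``handle`` are placeholders, not part of this commit) would unpack the results like this::

    responses = conn.wait_for_responses(msg_a, msg_b, fail_on_error=False)
    for success, response in responses:
        if success:
            handle(response)  # the normal query response
        else:
            log.warning("request failed: %s", response)  # response is an Exception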
@@ -677,9 +686,10 @@ class Connection(object):

 class ResponseWaiter(object):

-    def __init__(self, connection, num_responses):
+    def __init__(self, connection, num_responses, fail_on_error):
         self.connection = connection
         self.pending = num_responses
+        self.fail_on_error = fail_on_error
         self.error = None
         self.responses = [None] * num_responses
         self.event = Event()
@@ -688,15 +698,33 @@ class ResponseWaiter(object):
         with self.connection.lock:
             self.connection.in_flight -= 1
         if isinstance(response, Exception):
-            self.error = response
-            self.event.set()
-        else:
-            self.responses[index] = response
-            self.pending -= 1
-            if not self.pending:
+            if hasattr(response, 'to_exception'):
+                response = response.to_exception()
+            if self.fail_on_error:
+                self.error = response
+                self.event.set()
+            else:
+                self.responses[index] = (False, response)
+        else:
+            if not self.fail_on_error:
+                self.responses[index] = (True, response)
+            else:
+                self.responses[index] = response
+
+        self.pending -= 1
+        if not self.pending:
             self.event.set()

     def deliver(self, timeout=None):
         """
+        If fail_on_error was set to False, a list of (success, response)
+        tuples will be returned. If success is False, response will be
+        an Exception. Otherwise, response will be the normal query response.
+
         If fail_on_error was left as True and one of the requests
         failed, the corresponding Exception will be raised. Otherwise,
         the normal response will be returned.
         """
         self.event.wait(timeout)
         if self.error:
             raise self.error
@@ -806,6 +806,10 @@ class TupleType(_ParameterizedType):

     @classmethod
     def serialize_safe(cls, val, protocol_version):
+        if len(val) > len(cls.subtypes):
+            raise ValueError("Expected %d items in a tuple, but got %d: %s" %
+                             (len(cls.subtypes), len(val), val))
+
         proto_version = max(3, protocol_version)
         buf = io.BytesIO()
         for item, subtype in zip(val, cls.subtypes):
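The new guard matters because the ``zip()`` on the following line silently truncates to the shorter sequence, so surplus tuple items would otherwise be dropped with no error at all (PYTHON-98). A standalone sketch of the check, with simplified names::

    def check_tuple_arity(values, subtypes):
        # zip(values, subtypes) silently ignores surplus items,
        # so reject them up front; missing trailing items are fine
        # (they are serialized as nulls).
        if len(values) > len(subtypes):
            raise ValueError("Expected %d items in a tuple, but got %d: %s"
                             % (len(subtypes), len(values), values))

    check_tuple_arity(('a', 'b'), ['text', 'text', 'int'])  # short tuples are allowed
    try:
        check_tuple_arity((1, 2, 3, 4), ['int', 'int'])
    except ValueError as exc:
        print(exc)  # Expected 2 items in a tuple, but got 4: (1, 2, 3, 4)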
@@ -11,6 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+These functions are used to convert Python objects into CQL strings.
+When non-prepared statements are executed, these encoder functions are
+called on each query parameter.
+"""

 import logging
 log = logging.getLogger(__name__)
@@ -44,104 +49,156 @@ def cql_quote(term):
     return str(term)


-def cql_encode_none(val):
-    return 'NULL'
-
-
-def cql_encode_unicode(val):
-    return cql_quote(val.encode('utf-8'))
-
-
-def cql_encode_str(val):
-    return cql_quote(val)
-
-
-if six.PY3:
-    def cql_encode_bytes(val):
-        return (b'0x' + hexlify(val)).decode('utf-8')
-elif sys.version_info >= (2, 7):
-    def cql_encode_bytes(val):  # noqa
-        return b'0x' + hexlify(val)
-else:
-    # python 2.6 requires string or read-only buffer for hexlify
-    def cql_encode_bytes(val):  # noqa
-        return b'0x' + hexlify(buffer(val))
-
-
-def cql_encode_object(val):
-    return str(val)
-
-
-def cql_encode_datetime(val):
-    timestamp = calendar.timegm(val.utctimetuple())
-    return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
-
-
-def cql_encode_date(val):
-    return "'%s'" % val.strftime('%Y-%m-%d-0000')
-
-
-def cql_encode_sequence(val):
-    return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v)
-                                 for v in val)
-
-cql_encode_tuple = cql_encode_sequence
-
-
-def cql_encode_map_collection(val):
-    return '{ %s }' % ' , '.join('%s : %s' % (
-        cql_encode_all_types(k),
-        cql_encode_all_types(v)
-    ) for k, v in six.iteritems(val))
-
-
-def cql_encode_list_collection(val):
-    return '[ %s ]' % ' , '.join(map(cql_encode_all_types, val))
-
-
-def cql_encode_set_collection(val):
-    return '{ %s }' % ' , '.join(map(cql_encode_all_types, val))
-
-
-def cql_encode_all_types(val):
-    return cql_encoders.get(type(val), cql_encode_object)(val)
-
-
-cql_encoders = {
-    float: cql_encode_object,
-    bytearray: cql_encode_bytes,
-    str: cql_encode_str,
-    int: cql_encode_object,
-    UUID: cql_encode_object,
-    datetime.datetime: cql_encode_datetime,
-    datetime.date: cql_encode_date,
-    dict: cql_encode_map_collection,
-    OrderedDict: cql_encode_map_collection,
-    list: cql_encode_list_collection,
-    tuple: cql_encode_list_collection,
-    set: cql_encode_set_collection,
-    frozenset: cql_encode_set_collection,
-    types.GeneratorType: cql_encode_list_collection
-}
-
-if six.PY2:
-    cql_encoders.update({
-        unicode: cql_encode_unicode,
-        buffer: cql_encode_bytes,
-        long: cql_encode_object,
-        types.NoneType: cql_encode_none,
-    })
-else:
-    cql_encoders.update({
-        memoryview: cql_encode_bytes,
-        bytes: cql_encode_bytes,
-        type(None): cql_encode_none,
-    })
-
-# sortedset is optional
-try:
-    from blist import sortedset
-    cql_encoders.update({
-        sortedset: cql_encode_set_collection
-    })
-except ImportError:
-    pass
+class ValueSequence(list):
+    pass
+
+
+class Encoder(object):
+    """
+    A container for mapping python types to CQL string literals when working
+    with non-prepared statements. The type :attr:`~.Encoder.mapping` can be
+    directly customized by users.
+    """
+
+    mapping = None
+    """
+    A map of python types to encoder functions.
+    """
+
+    def __init__(self):
+        self.mapping = {
+            float: self.cql_encode_object,
+            bytearray: self.cql_encode_bytes,
+            str: self.cql_encode_str,
+            int: self.cql_encode_object,
+            UUID: self.cql_encode_object,
+            datetime.datetime: self.cql_encode_datetime,
+            datetime.date: self.cql_encode_date,
+            dict: self.cql_encode_map_collection,
+            OrderedDict: self.cql_encode_map_collection,
+            list: self.cql_encode_list_collection,
+            tuple: self.cql_encode_list_collection,
+            set: self.cql_encode_set_collection,
+            frozenset: self.cql_encode_set_collection,
+            types.GeneratorType: self.cql_encode_list_collection,
+            ValueSequence: self.cql_encode_sequence
+        }
+
+        if six.PY2:
+            self.mapping.update({
+                unicode: self.cql_encode_unicode,
+                buffer: self.cql_encode_bytes,
+                long: self.cql_encode_object,
+                types.NoneType: self.cql_encode_none,
+            })
+        else:
+            self.mapping.update({
+                memoryview: self.cql_encode_bytes,
+                bytes: self.cql_encode_bytes,
+                type(None): self.cql_encode_none,
+            })
+
+        # sortedset is optional
+        try:
+            from blist import sortedset
+            self.mapping.update({
+                sortedset: self.cql_encode_set_collection
+            })
+        except ImportError:
+            pass
+
+    def cql_encode_none(self, val):
+        """
+        Converts :const:`None` to the string 'NULL'.
+        """
+        return 'NULL'
+
+    def cql_encode_unicode(self, val):
+        """
+        Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping.
+        """
+        return cql_quote(val.encode('utf-8'))
+
+    def cql_encode_str(self, val):
+        """
+        Escapes quotes in :class:`str` objects.
+        """
+        return cql_quote(val)
+
+    if six.PY3:
+        def cql_encode_bytes(self, val):
+            return (b'0x' + hexlify(val)).decode('utf-8')
+    elif sys.version_info >= (2, 7):
+        def cql_encode_bytes(self, val):  # noqa
+            return b'0x' + hexlify(val)
+    else:
+        # python 2.6 requires string or read-only buffer for hexlify
+        def cql_encode_bytes(self, val):  # noqa
+            return b'0x' + hexlify(buffer(val))
+
+    def cql_encode_object(self, val):
+        """
+        Default encoder for all objects that do not have a specific encoder function
+        registered. This function simply calls :meth:`str()` on the object.
+        """
+        return str(val)
+
+    def cql_encode_datetime(self, val):
+        """
+        Converts a :class:`datetime.datetime` object to a (string) integer timestamp
+        with millisecond precision.
+        """
+        timestamp = calendar.timegm(val.utctimetuple())
+        return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
+
+    def cql_encode_date(self, val):
+        """
+        Converts a :class:`datetime.date` object to a string with format
+        ``YYYY-MM-DD-0000``.
+        """
+        return "'%s'" % val.strftime('%Y-%m-%d-0000')
+
+    def cql_encode_sequence(self, val):
+        """
+        Converts a sequence to a string of the form ``(item1, item2, ...)``. This
+        is suitable for ``IN`` value lists.
+        """
+        return '( %s )' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v)
+                                     for v in val)

+    cql_encode_tuple = cql_encode_sequence
+    """
+    Converts a sequence to a string of the form ``(item1, item2, ...)``. This
+    is suitable for ``tuple`` type columns.
+    """
+
+    def cql_encode_map_collection(self, val):
+        """
+        Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``.
+        This is suitable for ``map`` type columns.
+        """
+        return '{ %s }' % ' , '.join('%s : %s' % (
+            self.mapping.get(type(k), self.cql_encode_object)(k),
+            self.mapping.get(type(v), self.cql_encode_object)(v)
+        ) for k, v in six.iteritems(val))
+
+    def cql_encode_list_collection(self, val):
+        """
+        Converts a sequence to a string of the form ``[item1, item2, ...]``. This
+        is suitable for ``list`` type columns.
+        """
+        return '[ %s ]' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val)
+
+    def cql_encode_set_collection(self, val):
+        """
+        Converts a sequence to a string of the form ``{item1, item2, ...}``. This
+        is suitable for ``set`` type columns.
+        """
+        return '{ %s }' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val)
+
+    def cql_encode_all_types(self, val):
+        """
+        Converts any type into a CQL string, defaulting to ``cql_encode_object``
+        if :attr:`~Encoder.mapping` does not contain an entry for the type.
+        """
+        return self.mapping.get(type(val), self.cql_encode_object)(val)
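Note that dispatch is by exact type (``self.mapping.get(type(val), ...)``), so a subclass of a registered type falls back to the default encoder instead of inheriting its parent's. A minimal standalone sketch of the same dispatch idea, with simplified names of my own choosing::

    class MiniEncoder(object):
        def __init__(self):
            self.mapping = {str: self.encode_str, type(None): lambda v: 'NULL'}

        def encode_str(self, val):
            # double single quotes, as cql_quote does, so the value
            # cannot break out of the string literal
            return "'%s'" % val.replace("'", "''")

        def encode(self, val):
            # exact-type lookup: an unregistered type falls back to str()
            return self.mapping.get(type(val), str)(val)

    enc = MiniEncoder()
    assert enc.encode("O'Reilly") == "'O''Reilly'"
    assert enc.encode(None) == 'NULL'
    assert enc.encode(42) == '42'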
@@ -450,7 +450,7 @@ class SimpleStrategy(ReplicationStrategy):
             while len(hosts) < self.replication_factor and j < len(ring):
                 token = ring[(i + j) % len(ring)]
                 host = token_to_host_owner[token]
-                if not host in hosts:
+                if host not in hosts:
                     hosts.append(host)
                 j += 1
@@ -597,7 +597,7 @@ class KeyspaceMetadata(object):
         self.user_types = {}

     def export_as_string(self):
-        return "\n".join([self.as_cql_query()] + [t.export_as_string() for t in self.tables.values()])
+        return "\n\n".join([self.as_cql_query()] + self.user_type_strings() + [t.export_as_string() for t in self.tables.values()])

     def as_cql_query(self):
         ret = "CREATE KEYSPACE %s WITH replication = %s " % (
@@ -605,13 +605,51 @@ class KeyspaceMetadata(object):
             self.replication_strategy.export_for_schema())
         return ret + (' AND durable_writes = %s;' % ("true" if self.durable_writes else "false"))

+    def user_type_strings(self):
+        user_type_strings = []
+        types = self.user_types.copy()
+        keys = sorted(types.keys())
+        for k in keys:
+            if k in types:
+                self.resolve_user_types(k, types, user_type_strings)
+        return user_type_strings
+
+    def resolve_user_types(self, key, types, user_type_strings):
+        user_type = types.pop(key)
+        for field_type in user_type.field_types:
+            if field_type.cassname == 'UserType' and field_type.typename in types:
+                self.resolve_user_types(field_type.typename, types, user_type_strings)
+        user_type_strings.append(user_type.as_cql_query(formatted=True))
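``resolve_user_types`` pops a type before appending its CQL, recursing first into any user-defined types referenced by its fields, so every ``CREATE TYPE`` statement is emitted after the types it depends on. A standalone sketch of the ordering, with a hypothetical dependency map::

    def resolve(key, remaining, out, deps):
        remaining.discard(key)
        for dep in deps.get(key, ()):
            if dep in remaining:
                resolve(dep, remaining, out, deps)
        out.append(key)

    # hypothetical types: user references address, address references zip_info
    deps = {'user': ['address'], 'address': ['zip_info']}
    remaining, out = {'address', 'user', 'zip_info'}, []
    for name in sorted(list(remaining)):
        if name in remaining:
            resolve(name, remaining, out, deps)
    assert out == ['zip_info', 'address', 'user']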
+
+
+class UserType(object):
+    """
+    A user defined type, as created by ``CREATE TYPE`` statements.
+
+    User-defined types were introduced in Cassandra 2.1.
+
+    .. versionadded:: 2.1.0
+    """
+
+    keyspace = None
+    """
+    The string name of the keyspace in which this type is defined.
+    """
+
+    name = None
+    """
+    The name of this type.
+    """
+
+    field_names = None
+    """
+    An ordered list of the names for each field in this user-defined type.
+    """
+
+    field_types = None
+    """
+    An ordered list of the types for each field in this user-defined type.
+    """
+
+    def __init__(self, keyspace, name, field_names, field_types):
+        self.keyspace = keyspace
@@ -619,6 +657,32 @@ class UserType(object):
         self.field_names = field_names
         self.field_types = field_types

+    def as_cql_query(self, formatted=False):
+        """
+        Returns a CQL query that can be used to recreate this type.
+        If `formatted` is set to :const:`True`, extra whitespace will
+        be added to make the query more readable.
+        """
+        ret = "CREATE TYPE %s.%s (%s" % (
+            protect_name(self.keyspace),
+            protect_name(self.name),
+            "\n" if formatted else "")
+
+        if formatted:
+            field_join = ",\n"
+            padding = "    "
+        else:
+            field_join = ", "
+            padding = ""
+
+        fields = []
+        for field_name, field_type in zip(self.field_names, self.field_types):
+            fields.append("%s %s" % (protect_name(field_name), field_type.cql_parameterized_type()))
+
+        ret += field_join.join("%s%s" % (padding, field) for field in fields)
+        ret += "\n);" if formatted else ");"
+        return ret
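For a hypothetical ``address (street text, zipcode int)`` type in keyspace ``mykeyspace``, the unformatted and formatted modes differ only in whitespace; the output is roughly::

    CREATE TYPE mykeyspace.address (street text, zipcode int);

    CREATE TYPE mykeyspace.address (
        street text,
        zipcode int
    );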
+
+
 class TableMetadata(object):
     """
@@ -17,7 +17,6 @@ Connection pooling and host management.
 """

 import logging
-import re
 import socket
 import time
 from threading import Lock, RLock, Condition
@@ -42,13 +41,6 @@ class NoConnectionsAvailable(Exception):
     pass


-# example matches:
-# 1.0.0
-# 1.0.0-beta1
-# 2.0-SNAPSHOT
-version_re = re.compile(r"(?P<major>\d+)\.(?P<minor>\d+)(?:\.(?P<patch>\d+))?(?:-(?P<label>\w+))?")
-
-
 class Host(object):
     """
     Represents a single Cassandra node.
@@ -72,12 +64,6 @@ class Host(object):
     up or down.
     """

-    version = None
-    """
-    A tuple representing the Cassandra version for this host. This will
-    remain as :const:`None` if the version is unknown.
-    """
-
     _datacenter = None
     _rack = None
     _reconnection_handler = None
@@ -115,14 +101,6 @@ class Host(object):
         self._datacenter = datacenter
         self._rack = rack

-    def set_version(self, version_string):
-        match = version_re.match(version_string)
-        if match is not None:
-            version = [int(match.group('major')), int(match.group('minor')), int(match.group('patch') or 0)]
-            if match.group('label'):
-                version.append(match.group('label'))
-            self.version = tuple(version)
-
     def set_up(self):
         if not self.is_up:
             log.debug("Host %s is now marked up", self.address)
@@ -27,8 +27,8 @@ import six

 from cassandra import ConsistencyLevel, OperationTimedOut
 from cassandra.cqltypes import unix_time_from_uuid1
-from cassandra.encoder import (cql_encoders, cql_encode_object,
-                               cql_encode_sequence)
+from cassandra.encoder import Encoder
+import cassandra.encoder
 from cassandra.util import OrderedDict

 import logging
@@ -57,7 +57,7 @@ def tuple_factory(colnames, rows):

     Example::

-        >>> from cassandra.query import named_tuple_factory
+        >>> from cassandra.query import tuple_factory
         >>> session = cluster.connect('mykeyspace')
         >>> session.row_factory = tuple_factory
         >>> rows = session.execute("SELECT name, age FROM users LIMIT 1")
@@ -133,6 +133,9 @@ def ordered_dict_factory(colnames, rows):
     return [OrderedDict(zip(colnames, row)) for row in rows]


+FETCH_SIZE_UNSET = object()
+
+
 class Statement(object):
     """
     An abstract class representing a single query. There are three subclasses:
@@ -160,7 +163,7 @@ class Statement(object):
     the Session this is executed in will be used.
     """

-    fetch_size = None
+    fetch_size = FETCH_SIZE_UNSET
     """
     How many rows will be fetched at a time. This overrides the default
     of :attr:`.Session.default_fetch_size`
@@ -175,14 +178,14 @@ class Statement(object):
     _routing_key = None

     def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
-                 serial_consistency_level=None, fetch_size=None):
+                 serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET):
         self.retry_policy = retry_policy
         if consistency_level is not None:
             self.consistency_level = consistency_level
         if serial_consistency_level is not None:
             self.serial_consistency_level = serial_consistency_level
-        if fetch_size is not None:
-            self.fetch_size = fetch_size
+        if fetch_size is not FETCH_SIZE_UNSET:
+            self.fetch_size = fetch_size
         self._routing_key = routing_key

     def _get_routing_key(self):
@@ -315,9 +318,11 @@ class PreparedStatement(object):

     _protocol_version = None

+    fetch_size = FETCH_SIZE_UNSET
+
     def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace,
                  protocol_version, consistency_level=None, serial_consistency_level=None,
-                 fetch_size=None):
+                 fetch_size=FETCH_SIZE_UNSET):
         self.column_metadata = column_metadata
         self.query_id = query_id
         self.routing_key_indexes = routing_key_indexes
@@ -326,7 +331,8 @@ class PreparedStatement(object):
         self._protocol_version = protocol_version
         self.consistency_level = consistency_level
         self.serial_consistency_level = serial_consistency_level
-        self.fetch_size = fetch_size
+        if fetch_size is not FETCH_SIZE_UNSET:
+            self.fetch_size = fetch_size

     @classmethod
     def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace, protocol_version):
@@ -619,8 +625,8 @@ class BatchStatement(Statement):
         """
         if isinstance(statement, six.string_types):
             if parameters:
-                encoders = cql_encoders if self._session is None else self._session.encoders
-                statement = bind_params(statement, parameters, encoders)
+                encoder = Encoder() if self._session is None else self._session.encoder
+                statement = bind_params(statement, parameters, encoder)
             self._statements_and_parameters.append((False, statement, ()))
         elif isinstance(statement, PreparedStatement):
             query_id = statement.query_id
@@ -638,8 +644,8 @@ class BatchStatement(Statement):
             # it must be a SimpleStatement
             query_string = statement.query_string
             if parameters:
-                encoders = cql_encoders if self._session is None else self._session.encoders
-                query_string = bind_params(query_string, parameters, encoders)
+                encoder = Encoder() if self._session is None else self._session.encoder
+                query_string = bind_params(query_string, parameters, encoder)
             self._statements_and_parameters.append((False, query_string, ()))
         return self
@@ -659,33 +665,27 @@ class BatchStatement(Statement):
     __repr__ = __str__


-class ValueSequence(object):
-    """
-    A wrapper class that is used to specify that a sequence of values should
-    be treated as a CQL list of values instead of a single column collection when used
-    as part of the `parameters` argument for :meth:`.Session.execute()`.
+ValueSequence = cassandra.encoder.ValueSequence
+"""
+A wrapper class that is used to specify that a sequence of values should
+be treated as a CQL list of values instead of a single column collection when used
+as part of the `parameters` argument for :meth:`.Session.execute()`.

-    This is typically needed when supplying a list of keys to select.
-    For example::
+This is typically needed when supplying a list of keys to select.
+For example::

-        >>> my_user_ids = ('alice', 'bob', 'charles')
-        >>> query = "SELECT * FROM users WHERE user_id IN %s"
-        >>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
+    >>> my_user_ids = ('alice', 'bob', 'charles')
+    >>> query = "SELECT * FROM users WHERE user_id IN %s"
+    >>> session.execute(query, parameters=[ValueSequence(my_user_ids)])

-    """
-
-    def __init__(self, sequence):
-        self.sequence = sequence
-
-    def __str__(self):
-        return cql_encode_sequence(self.sequence)
+"""


-def bind_params(query, params, encoders):
+def bind_params(query, params, encoder):
     if isinstance(params, dict):
-        return query % dict((k, encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
+        return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params))
     else:
-        return query % tuple(encoders.get(type(v), cql_encode_object)(v) for v in params)
+        return query % tuple(encoder.cql_encode_all_types(v) for v in params)
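``bind_params`` renders each parameter to a CQL literal via the encoder and then interpolates with the ``%`` operator, which is why quote escaping in the encoders is essential. A standalone sketch with a deliberately tiny encoder::

    def encode(val):
        if val is None:
            return 'NULL'
        if isinstance(val, str):
            # double single quotes so the value cannot break out of the literal
            return "'%s'" % val.replace("'", "''")
        return str(val)

    def bind(query, params):
        if isinstance(params, dict):
            return query % dict((k, encode(v)) for k, v in params.items())
        return query % tuple(encode(v) for v in params)

    assert (bind("INSERT INTO t (k, v) VALUES (%s, %s)", (1, "O'Brien"))
            == "INSERT INTO t (k, v) VALUES (1, 'O''Brien')")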


 class TraceUnavailable(Exception):
@@ -43,6 +43,8 @@

 .. automethod:: shutdown

+.. automethod:: register_user_type
+
 .. automethod:: register_listener

 .. automethod:: unregister_listener
@@ -63,6 +65,10 @@

 .. autoattribute:: default_fetch_size

+.. autoattribute:: use_client_timestamp
+
+.. autoattribute:: encoder
+
 .. automethod:: execute(statement[, parameters][, timeout][, trace])

 .. automethod:: execute_async(statement[, parameters][, trace])
@@ -98,3 +104,5 @@

 .. autoexception:: NoHostAvailable ()
    :members:
+
+.. autoexception:: UserTypeDoesNotExist ()
new file: docs/api/cassandra/encoder.rst
@@ -0,0 +1,36 @@
+``cassandra.encoder`` - Encoders for non-prepared Statements
+============================================================
+
+.. module:: cassandra.encoder
+
+.. autoclass:: Encoder ()
+
+.. autoattribute:: cassandra.encoder.Encoder.mapping
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_none ()
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_object ()
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_all_types ()
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_sequence ()
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_str ()
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_unicode ()
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_bytes ()
+
+   Converts strings, buffers, and bytearrays into CQL blob literals.
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_datetime ()
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_date ()
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_map_collection ()
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_list_collection ()
+
+.. automethod:: cassandra.encoder.Encoder.cql_encode_set_collection ()
+
+.. automethod:: cql_encode_tuple ()
@@ -2,7 +2,6 @@
 ===========================================

 .. module:: cassandra.metrics
-   :members:

 .. autoclass:: cassandra.metrics.Metrics ()
    :members:
@@ -34,8 +34,18 @@

 .. autoattribute:: COUNTER

-.. autoclass:: ValueSequence
-   :members:
+.. autoclass:: cassandra.query.ValueSequence
+
+   A wrapper class that is used to specify that a sequence of values should
+   be treated as a CQL list of values instead of a single column collection when used
+   as part of the `parameters` argument for :meth:`.Session.execute()`.
+
+   This is typically needed when supplying a list of keys to select.
+   For example::
+
+      >>> my_user_ids = ('alice', 'bob', 'charles')
+      >>> query = "SELECT * FROM users WHERE user_id IN %s"
+      >>> session.execute(query, parameters=[ValueSequence(my_user_ids)])

 .. autoclass:: QueryTrace ()
    :members:
@@ -12,6 +12,7 @@ API Documentation
     cassandra/metrics
     cassandra/query
     cassandra/pool
+    cassandra/encoder
     cassandra/decoder
     cassandra/concurrent
     cassandra/connection
@@ -137,7 +137,7 @@ when you execute:
     """
     INSERT INTO users (name, credits, user_id)
     VALUES (%s, %s, %s)
-    """
+    """,
     ("John O'Reilly", 42, uuid.uuid1())
     )
@@ -40,6 +40,7 @@ MULTIDC_CLUSTER_NAME = 'multidc_test_cluster'
 CCM_CLUSTER = None

 CASSANDRA_DIR = os.getenv('CASSANDRA_DIR', None)
+CASSANDRA_HOME = os.getenv('CASSANDRA_HOME', None)
 CASSANDRA_VERSION = os.getenv('CASSANDRA_VERSION', '2.0.9')

 if CASSANDRA_VERSION.startswith('1'):
@@ -97,6 +98,8 @@ def get_node(node_id):
 def setup_package():
     if CASSANDRA_DIR:
         log.info("Using Cassandra dir: %s", CASSANDRA_DIR)
+    elif CASSANDRA_HOME:
+        log.info("Using Cassandra home: %s", CASSANDRA_HOME)
     else:
         log.info('Using Cassandra version: %s', CASSANDRA_VERSION)
     try:
new file: tests/integration/datatype_utils.py
@@ -0,0 +1,139 @@
+# Copyright 2013-2014 DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from decimal import Decimal
+import datetime
+from uuid import UUID
+import pytz
+
+try:
+    from blist import sortedset
+except ImportError:
+    sortedset = set  # noqa
+
+DATA_TYPE_PRIMITIVES = [
+    'ascii',
+    'bigint',
+    'blob',
+    'boolean',
+    # 'counter', counters are not allowed inside tuples
+    'decimal',
+    'double',
+    'float',
+    'inet',
+    'int',
+    'text',
+    'timestamp',
+    'timeuuid',
+    'uuid',
+    'varchar',
+    'varint',
+]
+
+DATA_TYPE_NON_PRIMITIVE_NAMES = [
+    'list',
+    'set',
+    'map',
+    'tuple'
+]
+
+
+def get_sample_data():
+    """
+    Create a standard set of sample inputs for testing.
+    """
+
+    sample_data = {}
+
+    for datatype in DATA_TYPE_PRIMITIVES:
+        if datatype == 'ascii':
+            sample_data[datatype] = 'ascii'
+
+        elif datatype == 'bigint':
+            sample_data[datatype] = 2 ** 63 - 1
+
+        elif datatype == 'blob':
+            sample_data[datatype] = bytearray(b'hello world')
+
+        elif datatype == 'boolean':
+            sample_data[datatype] = True
+
+        elif datatype == 'counter':
+            # Not supported in an insert statement
+            pass
+
+        elif datatype == 'decimal':
+            sample_data[datatype] = Decimal('12.3E+7')
+
+        elif datatype == 'double':
+            sample_data[datatype] = 1.23E+8
+
+        elif datatype == 'float':
+            sample_data[datatype] = 3.4028234663852886e+38
+
+        elif datatype == 'inet':
+            sample_data[datatype] = '123.123.123.123'
+
+        elif datatype == 'int':
+            sample_data[datatype] = 2147483647
+
+        elif datatype == 'text':
+            sample_data[datatype] = 'text'
+
+        elif datatype == 'timestamp':
+            sample_data[datatype] = datetime.datetime.fromtimestamp(872835240, tz=pytz.timezone('America/New_York')).astimezone(pytz.UTC).replace(tzinfo=None)
+
+        elif datatype == 'timeuuid':
+            sample_data[datatype] = UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66')
+
+        elif datatype == 'uuid':
+            sample_data[datatype] = UUID('067e6162-3b6f-4ae2-a171-2470b63dff00')
+
+        elif datatype == 'varchar':
+            sample_data[datatype] = 'varchar'
+
+        elif datatype == 'varint':
+            sample_data[datatype] = int(str(2147483647) + '000')
+
+        else:
+            raise Exception('Missing handling of %s.' % datatype)
+
+    return sample_data
+
+SAMPLE_DATA = get_sample_data()
+
+
+def get_sample(datatype):
+    """
+    Helper method to access created sample data
+    """
+
+    return SAMPLE_DATA[datatype]
+
+
+def get_nonprim_sample(non_prim_type, datatype):
+    """
+    Helper method to access created sample data for non-primitives
+    """
+
+    if non_prim_type == 'list':
+        return [get_sample(datatype), get_sample(datatype)]
+    elif non_prim_type == 'set':
+        return sortedset([get_sample(datatype)])
+    elif non_prim_type == 'map':
+        if datatype == 'blob':
+            return {get_sample('ascii'): get_sample(datatype)}
+        else:
+            return {get_sample(datatype): get_sample(datatype)}
+    elif non_prim_type == 'tuple':
+        return (get_sample(datatype),)
+    else:
+        raise Exception('Missing handling of non-primitive type {0}.'.format(non_prim_type))
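A quick illustration of what these helpers return (assuming the plain ``set`` fallback when blist is unavailable)::

    assert get_sample('int') == 2147483647
    assert get_nonprim_sample('list', 'int') == [2147483647, 2147483647]
    assert get_nonprim_sample('map', 'blob') == {'ascii': bytearray(b'hello world')}
    assert get_nonprim_sample('tuple', 'text') == ('text',)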
@@ -18,9 +18,8 @@ except ImportError:
     import unittest  # noqa

 from cassandra import ConsistencyLevel
-from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
-                             SimpleStatement, BatchStatement, BatchType,
-                             dict_factory)
+from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement,
+                             BatchStatement, BatchType, dict_factory)
 from cassandra.cluster import Cluster
 from cassandra.policies import HostDistance

@@ -45,14 +44,6 @@ class QueryTest(unittest.TestCase):
         session.execute(bound)
         self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')

-    def test_value_sequence(self):
-        """
-        Test the output of ValueSequences()
-        """
-
-        my_user_ids = ('alice', 'bob', 'charles')
-        self.assertEqual(str(ValueSequence(my_user_ids)), "( 'alice' , 'bob' , 'charles' )")
-
     def test_trace_prints_okay(self):
         """
         Code coverage to ensure trace prints to string without error
@@ -26,7 +26,7 @@ from itertools import cycle, count
 from six.moves import range
 from threading import Event

-from cassandra.cluster import Cluster
+from cassandra.cluster import Cluster, PagedResult
 from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args
 from cassandra.policies import HostDistance
 from cassandra.query import SimpleStatement
@@ -281,3 +281,79 @@ class QueryPagingTests(unittest.TestCase):
         for (success, result) in results:
             self.assertTrue(success)
             self.assertEquals(100, len(list(result)))
+
+    def test_fetch_size(self):
+        """
+        Ensure per-statement fetch_sizes override the default fetch size.
+        """
+        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
+                                    [(i, ) for i in range(100)])
+        execute_concurrent(self.session, list(statements_and_params))
+
+        prepared = self.session.prepare("SELECT * FROM test3rf.test")
+
+        self.session.default_fetch_size = 10
+        result = self.session.execute(prepared, [])
+        self.assertIsInstance(result, PagedResult)
+
+        self.session.default_fetch_size = 2000
+        result = self.session.execute(prepared, [])
+        self.assertIsInstance(result, list)
+
+        self.session.default_fetch_size = None
+        result = self.session.execute(prepared, [])
+        self.assertIsInstance(result, list)
+
+        self.session.default_fetch_size = 10
+
+        prepared.fetch_size = 2000
+        result = self.session.execute(prepared, [])
+        self.assertIsInstance(result, list)
+
+        prepared.fetch_size = None
+        result = self.session.execute(prepared, [])
+        self.assertIsInstance(result, list)
+
+        prepared.fetch_size = 10
+        result = self.session.execute(prepared, [])
+        self.assertIsInstance(result, PagedResult)
+
+        prepared.fetch_size = 2000
+        bound = prepared.bind([])
+        result = self.session.execute(bound, [])
+        self.assertIsInstance(result, list)
+
+        prepared.fetch_size = None
+        bound = prepared.bind([])
+        result = self.session.execute(bound, [])
+        self.assertIsInstance(result, list)
+
+        prepared.fetch_size = 10
+        bound = prepared.bind([])
+        result = self.session.execute(bound, [])
+        self.assertIsInstance(result, PagedResult)
+
+        bound.fetch_size = 2000
+        result = self.session.execute(bound, [])
+        self.assertIsInstance(result, list)
+
+        bound.fetch_size = None
+        result = self.session.execute(bound, [])
+        self.assertIsInstance(result, list)
+
+        bound.fetch_size = 10
+        result = self.session.execute(bound, [])
+        self.assertIsInstance(result, PagedResult)
+
+        s = SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None)
+        result = self.session.execute(s, [])
+        self.assertIsInstance(result, list)
+
+        s = SimpleStatement("SELECT * FROM test3rf.test")
+        result = self.session.execute(s, [])
+        self.assertIsInstance(result, PagedResult)
+
+        s = SimpleStatement("SELECT * FROM test3rf.test")
+        s.fetch_size = None
+        result = self.session.execute(s, [])
+        self.assertIsInstance(result, list)
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from tests.integration.datatype_utils import get_sample, DATA_TYPE_PRIMITIVES, DATA_TYPE_NON_PRIMITIVE_NAMES

 try:
     import unittest2 as unittest
@@ -33,7 +34,6 @@ except ImportError:
 from cassandra import InvalidRequest
 from cassandra.cluster import Cluster
 from cassandra.cqltypes import Int32Type, EMPTY
-from cassandra.encoder import cql_encode_tuple
 from cassandra.query import dict_factory
 from cassandra.util import OrderedDict

@@ -87,8 +87,8 @@ class TypeTests(unittest.TestCase):

         s.execute(query, params)
         expected_vals = [
             'key1',
             bytearray(b'blobyblob')
         ]

         results = s.execute("SELECT * FROM mytable")
@@ -404,12 +404,18 @@ class TypeTests(unittest.TestCase):
         self.assertEqual(dt.utctimetuple(), result.utctimetuple())

     def test_tuple_type(self):
         """
         Basic test of tuple functionality
         """
+        if self._cass_version < (2, 1, 0):
+            raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
+
         c = Cluster(protocol_version=PROTOCOL_VERSION)
         s = c.connect()
-        s.encoders[tuple] = cql_encode_tuple
+
+        # use this encoder in order to insert tuples
+        s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple

         s.execute("""CREATE KEYSPACE test_tuple_type
             WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
@@ -428,14 +434,241 @@ class TypeTests(unittest.TestCase):
         result = s.execute("SELECT b FROM mytable WHERE a=1")[0]
         self.assertEqual(partial_result, result.b)

+        # test single value tuples
+        subpartial = ('zoo',)
+        subpartial_result = subpartial + (None, None)
+        s.execute("INSERT INTO mytable (a, b) VALUES (2, %s)", parameters=(subpartial,))
+        result = s.execute("SELECT b FROM mytable WHERE a=2")[0]
+        self.assertEqual(subpartial_result, result.b)
+
         # test prepared statement
         prepared = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
-        s.execute(prepared, parameters=(2, complete))
-        s.execute(prepared, parameters=(3, partial))
+        s.execute(prepared, parameters=(3, complete))
+        s.execute(prepared, parameters=(4, partial))
+        s.execute(prepared, parameters=(5, subpartial))
+
+        # extra items in the tuple should result in an error
+        self.assertRaises(ValueError, s.execute, prepared, parameters=(0, (1, 2, 3, 4, 5, 6)))

         prepared = s.prepare("SELECT b FROM mytable WHERE a=?")
-        self.assertEqual(complete, s.execute(prepared, (2,))[0].b)
-        self.assertEqual(partial_result, s.execute(prepared, (3,))[0].b)
+        self.assertEqual(complete, s.execute(prepared, (3,))[0].b)
+        self.assertEqual(partial_result, s.execute(prepared, (4,))[0].b)
+        self.assertEqual(subpartial_result, s.execute(prepared, (5,))[0].b)
def test_tuple_type_varying_lengths(self):
|
||||
"""
|
||||
Test tuple types of lengths of 1, 2, 3, and 384 to ensure edge cases work
|
||||
as expected.
|
||||
"""
|
||||
|
||||
if self._cass_version < (2, 1, 0):
|
||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect()
|
||||
|
||||
# set the row_factory to dict_factory for programmatic access
|
||||
# set the encoder for tuples for the ability to write tuples
|
||||
s.row_factory = dict_factory
|
||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||
|
||||
s.execute("""CREATE KEYSPACE test_tuple_type_varying_lengths
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
|
||||
s.set_keyspace("test_tuple_type_varying_lengths")
|
||||
|
||||
# programmatically create the table with tuples of said sizes
|
||||
lengths = (1, 2, 3, 384)
|
||||
value_schema = []
|
||||
for i in lengths:
|
||||
value_schema += [' v_%s tuple<%s>' % (i, ', '.join(['int'] * i))]
|
||||
s.execute("CREATE TABLE mytable (k int PRIMARY KEY, %s)" % (', '.join(value_schema),))
|
||||
|
||||
# insert tuples into same key using different columns
|
||||
# and verify the results
|
||||
for i in lengths:
|
||||
created_tuple = tuple(range(0, i))
|
||||
|
||||
s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))
|
||||
|
||||
result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
|
||||
self.assertEqual(tuple(created_tuple), result['v_%s' % i])
|
||||
|
||||
def test_tuple_primitive_subtypes(self):
|
||||
"""
|
||||
Ensure tuple subtypes are appropriately handled.
|
||||
"""
|
||||
|
||||
if self._cass_version < (2, 1, 0):
|
||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect()
|
||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||
|
||||
s.execute("""CREATE KEYSPACE test_tuple_primitive_subtypes
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
|
||||
s.set_keyspace("test_tuple_primitive_subtypes")
|
||||
|
||||
s.execute("CREATE TABLE mytable ("
|
||||
"k int PRIMARY KEY, "
|
||||
"v tuple<%s>)" % ','.join(DATA_TYPE_PRIMITIVES))
|
||||
|
||||
for i in range(len(DATA_TYPE_PRIMITIVES)):
|
||||
# create tuples to be written and ensure they match with the expected response
|
||||
# responses have trailing None values for every element that has not been written
|
||||
created_tuple = [get_sample(DATA_TYPE_PRIMITIVES[j]) for j in range(i + 1)]
|
||||
response_tuple = tuple(created_tuple + [None for j in range(len(DATA_TYPE_PRIMITIVES) - i - 1)])
|
||||
written_tuple = tuple(created_tuple)
|
||||
|
||||
s.execute("INSERT INTO mytable (k, v) VALUES (%s, %s)", (i, written_tuple))
|
||||
|
||||
result = s.execute("SELECT v FROM mytable WHERE k=%s", (i,))[0]
|
||||
self.assertEqual(response_tuple, result.v)
|
||||
|
||||
    def test_tuple_non_primitive_subtypes(self):
        """
        Ensure tuple subtypes are appropriately handled for maps, sets, and lists.
        """

        if self._cass_version < (2, 1, 0):
            raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")

        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        # set the row_factory to dict_factory for programmatic access
        # set the encoder for tuples for the ability to write tuples
        s.row_factory = dict_factory
        s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple

        s.execute("""CREATE KEYSPACE test_tuple_non_primitive_subtypes
            WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
        s.set_keyspace("test_tuple_non_primitive_subtypes")

        values = []

        # create list values
        for datatype in DATA_TYPE_PRIMITIVES:
            values.append('v_{} tuple<list<{}>>'.format(len(values), datatype))

        # create set values
        for datatype in DATA_TYPE_PRIMITIVES:
            values.append('v_{} tuple<set<{}>>'.format(len(values), datatype))

        # create map values
        for datatype in DATA_TYPE_PRIMITIVES:
            datatype_1 = datatype_2 = datatype
            if datatype == 'blob':
                # unhashable type: 'bytearray'
                datatype_1 = 'ascii'
            values.append('v_{} tuple<map<{}, {}>>'.format(len(values), datatype_1, datatype_2))

        # make sure we're testing all non-primitive data types in the future
        if set(DATA_TYPE_NON_PRIMITIVE_NAMES) != set(['tuple', 'list', 'map', 'set']):
            raise NotImplementedError('Missing datatype not implemented: {}'.format(
                set(DATA_TYPE_NON_PRIMITIVE_NAMES) - set(['tuple', 'list', 'map', 'set'])
            ))

        # create table
        s.execute("CREATE TABLE mytable ("
                  "k int PRIMARY KEY, "
                  "%s)" % ', '.join(values))

        i = 0
        # test tuple<list<datatype>>
        for datatype in DATA_TYPE_PRIMITIVES:
            created_tuple = tuple([[get_sample(datatype)]])
            s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))

            result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
            self.assertEqual(created_tuple, result['v_%s' % i])
            i += 1

        # test tuple<set<datatype>>
        for datatype in DATA_TYPE_PRIMITIVES:
            created_tuple = tuple([sortedset([get_sample(datatype)])])
            s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))

            result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
            self.assertEqual(created_tuple, result['v_%s' % i])
            i += 1

        # test tuple<map<datatype, datatype>>
        for datatype in DATA_TYPE_PRIMITIVES:
            if datatype == 'blob':
                # unhashable type: 'bytearray'
                created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}])
            else:
                created_tuple = tuple([{get_sample(datatype): get_sample(datatype)}])

            s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))

            result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
            self.assertEqual(created_tuple, result['v_%s' % i])
            i += 1

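The ``unhashable type: 'bytearray'`` workaround above exists because blob samples arrive as ``bytearray`` instances, and ``bytearray`` cannot be used as a map key (CQL maps become Python dicts). Plain Python demonstrates the constraint::

    try:
        {bytearray(b'\x00'): 'value'}
    except TypeError as exc:
        print(exc)  # unhashable type: 'bytearray'
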
    def nested_tuples_schema_helper(self, depth):
        """
        Helper method for creating nested tuple schema
        """

        if depth == 0:
            return 'int'
        else:
            return 'tuple<%s>' % self.nested_tuples_schema_helper(depth - 1)

    def nested_tuples_creator_helper(self, depth):
        """
        Helper method for creating nested tuples
        """

        if depth == 0:
            return 303
        else:
            return (self.nested_tuples_creator_helper(depth - 1), )

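To make the recursion concrete, here is what the two helpers produce for small depths (dropping ``self`` for illustration)::

    assert nested_tuples_schema_helper(0) == 'int'
    assert nested_tuples_schema_helper(2) == 'tuple<tuple<int>>'
    assert nested_tuples_creator_helper(0) == 303
    assert nested_tuples_creator_helper(2) == ((303,),)
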
    def test_nested_tuples(self):
        """
        Ensure nested tuples are appropriately handled.
        """

        if self._cass_version < (2, 1, 0):
            raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")

        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        # set the row_factory to dict_factory for programmatic access
        # set the encoder for tuples for the ability to write tuples
        s.row_factory = dict_factory
        s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple

        s.execute("""CREATE KEYSPACE test_nested_tuples
            WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
        s.set_keyspace("test_nested_tuples")

        # create a table with multiple sizes of nested tuples
        s.execute("CREATE TABLE mytable ("
                  "k int PRIMARY KEY, "
                  "v_1 %s,"
                  "v_2 %s,"
                  "v_3 %s,"
                  "v_128 %s"
                  ")" % (self.nested_tuples_schema_helper(1),
                         self.nested_tuples_schema_helper(2),
                         self.nested_tuples_schema_helper(3),
                         self.nested_tuples_schema_helper(128)))

        for i in (1, 2, 3, 128):
            # create tuple
            created_tuple = self.nested_tuples_creator_helper(i)

            # write tuple
            s.execute("INSERT INTO mytable (k, v_%s) VALUES (%s, %s)", (i, i, created_tuple))

            # verify tuple was written and read correctly
            result = s.execute("SELECT v_%s FROM mytable WHERE k=%s", (i, i))[0]
            self.assertEqual(created_tuple, result['v_%s' % i])

    def test_unicode_query_string(self):
        c = Cluster(protocol_version=PROTOCOL_VERSION)

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cassandra.query import dict_factory

try:
    import unittest2 as unittest
@@ -22,9 +23,11 @@ log = logging.getLogger(__name__)

from collections import namedtuple

from cassandra.cluster import Cluster
from cassandra.cluster import Cluster, UserTypeDoesNotExist

from tests.integration import get_server_versions, PROTOCOL_VERSION
from tests.integration.datatype_utils import get_sample, get_nonprim_sample,\
    DATA_TYPE_PRIMITIVES, DATA_TYPE_NON_PRIMITIVE_NAMES

class TypeTests(unittest.TestCase):
@@ -227,3 +230,257 @@ class TypeTests(unittest.TestCase):
        self.assertTrue(type(row.b) is User)

        c.shutdown()

    def test_udt_sizes(self):
        """
        Test for ensuring extra-lengthy udts are handled correctly.
        """

        if self._cass_version < (2, 1, 0):
            raise unittest.SkipTest("User-defined types were introduced in Cassandra 2.1")

        MAX_TEST_LENGTH = 16384

        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        s.execute("""CREATE KEYSPACE test_udt_sizes
            WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
        s.set_keyspace("test_udt_sizes")

        # create the seed udt
        s.execute("CREATE TYPE lengthy_udt ({})".format(', '.join(['v_{} int'.format(i) for i in range(MAX_TEST_LENGTH)])))

        # create a table that uses the lengthy udt
        s.execute("CREATE TABLE mytable ("
                  "k int PRIMARY KEY, "
                  "v lengthy_udt)")

        # create and register the seed udt type
        udt = namedtuple('lengthy_udt', tuple(['v_{}'.format(i) for i in range(MAX_TEST_LENGTH)]))
        c.register_user_type("test_udt_sizes", "lengthy_udt", udt)

        # verify inserts and reads
        # no need for every length, only a spot-checked few and the largest one
        for i in (0, 1, 2, 3, MAX_TEST_LENGTH):
            # create udt with the first i fields set; the rest are left as None
            params = [j for j in range(i)] + [None for j in range(MAX_TEST_LENGTH - i)]
            created_udt = udt(*params)

            # write udt
            s.execute("INSERT INTO mytable (k, v) VALUES (0, %s)", (created_udt,))

            # verify udt was written and read correctly
            result = s.execute("SELECT v FROM mytable WHERE k=0")[0]
            self.assertEqual(created_udt, result.v)

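For reference, the DDL string built above, shown for a length of 3 instead of 16384::

    fields = ', '.join(['v_{} int'.format(i) for i in range(3)])
    print("CREATE TYPE lengthy_udt ({})".format(fields))
    # CREATE TYPE lengthy_udt (v_0 int, v_1 int, v_2 int)
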
    def nested_udt_helper(self, udts, i):
        """
        Helper for creating nested udts.
        """

        if i == 0:
            return udts[0](42, 'Bob')
        else:
            return udts[i](self.nested_udt_helper(udts, i - 1))

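For example, with two registered namedtuple classes the helper builds one level of wrapping (a sketch with hypothetical locals)::

    from collections import namedtuple

    depth_0 = namedtuple('depth_0', ('age', 'name'))
    depth_1 = namedtuple('depth_1', ('value',))
    udts = [depth_0, depth_1]
    # nested_udt_helper(udts, 1) == depth_1(value=depth_0(age=42, name='Bob'))
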
    def test_nested_registered_udts(self):
        """
        Test for ensuring nested udts are handled correctly.
        """

        if self._cass_version < (2, 1, 0):
            raise unittest.SkipTest("User-defined types were introduced in Cassandra 2.1")

        MAX_NESTING_DEPTH = 4  # TODO: Move to 128, or similar

        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        # set the row_factory to dict_factory for programmatically accessing values
        s.row_factory = dict_factory

        s.execute("""CREATE KEYSPACE test_nested_registered_udts
            WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
        s.set_keyspace("test_nested_registered_udts")

        # create the seed udt
        s.execute("CREATE TYPE depth_0 (age int, name text)")

        # create the nested udts
        for i in range(MAX_NESTING_DEPTH):
            s.execute("CREATE TYPE depth_{} (value depth_{})".format(i + 1, i))

        # create a table with multiple depths of nested udts
        # no need for all nested types, only a spot-checked few and the deepest one
        s.execute("CREATE TABLE mytable ("
                  "k int PRIMARY KEY, "
                  "v_0 depth_0, "
                  "v_1 depth_1, "
                  "v_2 depth_2, "
                  "v_3 depth_3, "
                  "v_{0} depth_{0})".format(MAX_NESTING_DEPTH))

        # create the udt container
        udts = []

        # create and register the seed udt type
        udt = namedtuple('depth_0', ('age', 'name'))
        udts.append(udt)
        c.register_user_type("test_nested_registered_udts", "depth_0", udts[0])

        # create and register the nested udt types
        for i in range(MAX_NESTING_DEPTH):
            udt = namedtuple('depth_{}'.format(i + 1), ('value',))
            udts.append(udt)
            c.register_user_type("test_nested_registered_udts", "depth_{}".format(i + 1), udts[i + 1])

        # verify inserts and reads
        for i in (0, 1, 2, 3, MAX_NESTING_DEPTH):
            # create udt
            udt = self.nested_udt_helper(udts, i)

            # write udt
            s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, udt))

            # verify udt was written and read correctly
            result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
            self.assertEqual(udt, result['v_%s' % i])

    def test_nested_unregistered_udts(self):
        """
        Test for ensuring nested unregistered udts are handled correctly.
        """

        # close copy of test_nested_registered_udts
        pass

    def test_nested_registered_udts_with_different_namedtuples(self):
        """
        Test for ensuring nested udts are handled correctly when the
        created namedtuples use names that differ from the cql types.
        """

        # close copy of test_nested_registered_udts
        pass

    def test_non_existing_types(self):
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        c.connect()
        User = namedtuple('user', ('age', 'name'))
        self.assertRaises(UserTypeDoesNotExist, c.register_user_type, "some_bad_keyspace", "user", User)
        self.assertRaises(UserTypeDoesNotExist, c.register_user_type, "system", "user", User)

    def test_primitive_datatypes(self):
        """
        Test for inserting each of the DATA_TYPE_PRIMITIVES into UDTs
        """
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        # create keyspace
        s.execute("""
            CREATE KEYSPACE test_primitive_datatypes
            WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
            """)
        s.set_keyspace("test_primitive_datatypes")

        # create UDT
        alpha_type_list = []
        start_index = ord('a')
        for i, datatype in enumerate(DATA_TYPE_PRIMITIVES):
            alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype))

        s.execute("""
            CREATE TYPE alldatatypes ({0})
            """.format(', '.join(alpha_type_list))
        )

        s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b alldatatypes)")

        # register UDT
        alphabet_list = []
        for i in range(ord('a'), ord('a') + len(DATA_TYPE_PRIMITIVES)):
            alphabet_list.append('{}'.format(chr(i)))
        Alldatatypes = namedtuple("alldatatypes", alphabet_list)
        c.register_user_type("test_primitive_datatypes", "alldatatypes", Alldatatypes)

        # insert UDT data
        params = []
        for datatype in DATA_TYPE_PRIMITIVES:
            params.append(get_sample(datatype))

        insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
        s.execute(insert, (0, Alldatatypes(*params)))

        # retrieve and verify data
        results = s.execute("SELECT * FROM mytable")
        self.assertEqual(1, len(results))

        row = results[0].b
        for expected, actual in zip(params, row):
            self.assertEqual(expected, actual)

        c.shutdown()

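Note the placeholder convention exercised above: non-prepared statements are bound client-side with ``%s`` through the session encoder, while prepared statements use ``?`` markers and are bound server-side. A minimal sketch reusing this test's table, where ``udt_value`` stands in for a registered UDT instance::

    # client-side binding: %s placeholders rendered by the session's Encoder
    s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, udt_value))

    # server-side binding: '?' markers in the prepared query string
    insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
    s.execute(insert, (0, udt_value))
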
    def test_nonprimitive_datatypes(self):
        """
        Test for inserting each of the DATA_TYPE_NON_PRIMITIVE collections into UDTs
        """
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        # create keyspace
        s.execute("""
            CREATE KEYSPACE test_nonprimitive_datatypes
            WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
            """)
        s.set_keyspace("test_nonprimitive_datatypes")

        # create UDT
        alpha_type_list = []
        start_index = ord('a')
        for i, nonprim_datatype in enumerate(DATA_TYPE_NON_PRIMITIVE_NAMES):
            for j, datatype in enumerate(DATA_TYPE_PRIMITIVES):
                if nonprim_datatype == "map":
                    type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j),
                                                                 nonprim_datatype, datatype)
                else:
                    type_string = "{0}_{1} {2}<{3}>".format(chr(start_index + i), chr(start_index + j),
                                                            nonprim_datatype, datatype)
                alpha_type_list.append(type_string)

        s.execute("""
            CREATE TYPE alldatatypes ({0})
            """.format(', '.join(alpha_type_list))
        )

        s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b alldatatypes)")

        # register UDT
        alphabet_list = []
        for i in range(ord('a'), ord('a') + len(DATA_TYPE_NON_PRIMITIVE_NAMES)):
            for j in range(ord('a'), ord('a') + len(DATA_TYPE_PRIMITIVES)):
                alphabet_list.append('{0}_{1}'.format(chr(i), chr(j)))

        Alldatatypes = namedtuple("alldatatypes", alphabet_list)
        c.register_user_type("test_nonprimitive_datatypes", "alldatatypes", Alldatatypes)

        # insert UDT data
        params = []
        for nonprim_datatype in DATA_TYPE_NON_PRIMITIVE_NAMES:
            for datatype in DATA_TYPE_PRIMITIVES:
                params.append(get_nonprim_sample(nonprim_datatype, datatype))

        insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
        s.execute(insert, (0, Alldatatypes(*params)))

        # retrieve and verify data
        results = s.execute("SELECT * FROM mytable")
        self.assertEqual(1, len(results))

        row = results[0].b
        for expected, actual in zip(params, row):
            self.assertEqual(expected, actual)

        c.shutdown()

@@ -240,21 +240,3 @@ class HostConnectionPoolTests(unittest.TestCase):
        self.assertEqual(a, b, 'Two Host instances should be equal when sharing.')
        self.assertNotEqual(a, c, 'Two Host instances should NOT be equal when using two different addresses.')
        self.assertNotEqual(b, c, 'Two Host instances should NOT be equal when using two different addresses.')


class HostTests(unittest.TestCase):

    def test_version_parsing(self):
        host = Host('127.0.0.1', SimpleConvictionPolicy)

        host.set_version("1.0.0")
        self.assertEqual((1, 0, 0), host.version)

        host.set_version("1.0")
        self.assertEqual((1, 0, 0), host.version)

        host.set_version("1.0.0-beta1")
        self.assertEqual((1, 0, 0, 'beta1'), host.version)

        host.set_version("1.0-SNAPSHOT")
        self.assertEqual((1, 0, 0, 'SNAPSHOT'), host.version)

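Restated as a standalone sketch (not the driver's implementation, just the behavior these assertions pin down: pad to three numeric components, keep any ``-suffix`` as a trailing label)::

    def parse_version(version_string):
        version, _, label = version_string.partition('-')
        parts = [int(part) for part in version.split('.')]
        parts += [0] * (3 - len(parts))  # pad "1.0" out to (1, 0, 0)
        return tuple(parts + ([label] if label else []))

    assert parse_version("1.0") == (1, 0, 0)
    assert parse_version("1.0.0-beta1") == (1, 0, 0, 'beta1')
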
@@ -15,14 +15,18 @@
try:
    import unittest2 as unittest
except ImportError:
    import unittest # noqa
    import unittest # noqa

from mock import Mock

import cassandra
from cassandra.cqltypes import IntegerType, AsciiType, TupleType
from cassandra.metadata import (Murmur3Token, MD5Token,
                                BytesToken, ReplicationStrategy,
                                NetworkTopologyStrategy, SimpleStrategy,
                                LocalStrategy, NoMurmur3, protect_name,
                                protect_names, protect_value, is_valid_name)
                                protect_names, protect_value, is_valid_name,
                                UserType, KeyspaceMetadata)
from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host

@@ -181,12 +185,12 @@ class TestNameEscaping(unittest.TestCase):
                'tests ?!@#$%^&*()',
                '1'
            ]),
            [
                'tests',
                "\"test's\"",
                '"tests ?!@#$%^&*()"',
                '"1"'
            ])
            [
                'tests',
                "\"test's\"",
                '"tests ?!@#$%^&*()"',
                '"1"'
            ])

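The quoting rule this fixture encodes, per name (the mappings are taken directly from the expected values above)::

    from cassandra.metadata import protect_name

    protect_name('tests')    # stays bare: tests
    protect_name("test's")   # double-quoted: "test's"
    protect_name('1')        # double-quoted: "1"
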
    def test_protect_value(self):
        """
@@ -248,3 +252,63 @@ class TestTokens(unittest.TestCase):
            self.fail('Tokens for ByteOrderedPartitioner should be only strings')
        except TypeError:
            pass

class TestKeyspaceMetadata(unittest.TestCase):

    def test_export_as_string_user_types(self):
        keyspace_name = 'test'
        keyspace = KeyspaceMetadata(keyspace_name, True, 'SimpleStrategy', dict(replication_factor=3))
        keyspace.user_types['a'] = UserType(keyspace_name, 'a', ['one', 'two'],
                                            [self.mock_user_type('UserType', 'c'),
                                             self.mock_user_type('IntType', 'int')])
        keyspace.user_types['b'] = UserType(keyspace_name, 'b', ['one', 'two', 'three'],
                                            [self.mock_user_type('UserType', 'd'),
                                             self.mock_user_type('IntType', 'int'),
                                             self.mock_user_type('UserType', 'a')])
        keyspace.user_types['c'] = UserType(keyspace_name, 'c', ['one'],
                                            [self.mock_user_type('IntType', 'int')])
        keyspace.user_types['d'] = UserType(keyspace_name, 'd', ['one'],
                                            [self.mock_user_type('UserType', 'c')])

        self.assertEqual("""CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} AND durable_writes = true;

CREATE TYPE test.c (
    one int
);

CREATE TYPE test.a (
    one c,
    two int
);

CREATE TYPE test.d (
    one c
);

CREATE TYPE test.b (
    one d,
    two int,
    three a
);""", keyspace.export_as_string())

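The expected output above lists the types in dependency order (``c`` before ``a``, ``d`` before ``b``), since a UDT must exist before another type can reference it. A minimal dependency-first ordering sketch under that assumption (not the driver's actual implementation)::

    deps = {'a': ['c'], 'b': ['d', 'a'], 'c': [], 'd': ['c']}

    def emit(name, done, out):
        # emit every referenced type before the type itself
        for dep in deps[name]:
            emit(dep, done, out)
        if name not in done:
            done.add(name)
            out.append(name)

    out, done = [], set()
    for name in sorted(deps):
        emit(name, done, out)
    print(out)  # ['c', 'a', 'd', 'b'] -- matches the order above
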
    def mock_user_type(self, cassname, typename):
        return Mock(**{'cassname': cassname, 'typename': typename, 'cql_parameterized_type.return_value': typename})

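``mock_user_type`` leans on ``Mock``'s keyword expansion, where a dotted key like ``cql_parameterized_type.return_value`` configures a child method's return value. In isolation::

    from mock import Mock

    field_type = Mock(**{'cassname': 'IntType',
                         'typename': 'int',
                         'cql_parameterized_type.return_value': 'int'})
    assert field_type.cql_parameterized_type() == 'int'
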
class TestUserTypes(unittest.TestCase):

    def test_as_cql_query(self):
        field_types = [IntegerType, AsciiType, TupleType.apply_parameters([IntegerType, AsciiType])]
        udt = UserType("ks1", "mytype", ["a", "b", "c"], field_types)
        self.assertEqual("CREATE TYPE ks1.mytype (a varint, b ascii, c tuple<varint, ascii>);", udt.as_cql_query(formatted=False))

        self.assertEqual("""CREATE TYPE ks1.mytype (
    a varint,
    b ascii,
    c tuple<varint, ascii>
);""", udt.as_cql_query(formatted=True))

    def test_as_cql_query_name_escaping(self):
        udt = UserType("MyKeyspace", "MyType", ["AbA", "keyspace"], [AsciiType, AsciiType])
        self.assertEqual('CREATE TYPE "MyKeyspace"."MyType" ("AbA" ascii, "keyspace" ascii);', udt.as_cql_query(formatted=False))

@@ -17,7 +17,7 @@ try:
except ImportError:
    import unittest # noqa

from cassandra.encoder import cql_encoders
from cassandra.encoder import Encoder
from cassandra.query import bind_params, ValueSequence
from cassandra.query import PreparedStatement, BoundStatement
from cassandra.cqltypes import Int32Type
@@ -29,31 +29,31 @@ from six.moves import xrange

class ParamBindingTest(unittest.TestCase):

    def test_bind_sequence(self):
        result = bind_params("%s %s %s", (1, "a", 2.0), cql_encoders)
        result = bind_params("%s %s %s", (1, "a", 2.0), Encoder())
        self.assertEqual(result, "1 'a' 2.0")

    def test_bind_map(self):
        result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), cql_encoders)
        result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), Encoder())
        self.assertEqual(result, "1 'a' 2.0")

    def test_sequence_param(self):
        result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), cql_encoders)
        result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), Encoder())
        self.assertEqual(result, "( 1 , 'a' , 2.0 )")

    def test_generator_param(self):
        result = bind_params("%s", ((i for i in xrange(3)),), cql_encoders)
        result = bind_params("%s", ((i for i in xrange(3)),), Encoder())
        self.assertEqual(result, "[ 0 , 1 , 2 ]")

    def test_none_param(self):
        result = bind_params("%s", (None,), cql_encoders)
        result = bind_params("%s", (None,), Encoder())
        self.assertEqual(result, "NULL")

    def test_list_collection(self):
        result = bind_params("%s", (['a', 'b', 'c'],), cql_encoders)
        result = bind_params("%s", (['a', 'b', 'c'],), Encoder())
        self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")

    def test_set_collection(self):
        result = bind_params("%s", (set(['a', 'b']),), cql_encoders)
        result = bind_params("%s", (set(['a', 'b']),), Encoder())
        self.assertIn(result, ("{ 'a' , 'b' }", "{ 'b' , 'a' }"))

    def test_map_collection(self):
@@ -61,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
        vals['a'] = 'a'
        vals['b'] = 'b'
        vals['c'] = 'c'
        result = bind_params("%s", (vals,), cql_encoders)
        result = bind_params("%s", (vals,), Encoder())
        self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")

    def test_quote_escaping(self):
        result = bind_params("%s", ("""'ef''ef"ef""ef'""",), cql_encoders)
        result = bind_params("%s", ("""'ef''ef"ef""ef'""",), Encoder())
        self.assertEqual(result, """'''ef''''ef"ef""ef'''""")