Merge branch 'master' of github.com:datastax/python-driver into PYTHON-96-test

Joaquin Casares
2014-07-25 12:58:44 -05:00
24 changed files with 1305 additions and 249 deletions

View File

@@ -5,6 +5,15 @@ In Progress
Bug Fixes
---------
* Properly specify UDTs for columns in CREATE TABLE statements
* Avoid moving retries to a new host when using request ID zero (PYTHON-88)
* Don't ignore fetch_size arguments to Statement constructors (github-151)
* Allow disabling automatic paging on a per-statement basis when it's
enabled by default for the session (PYTHON-93; see the sketch below this list)
* Raise ValueError when tuple query parameters for prepared statements
have extra items (PYTHON-98)
* Correctly encode nested tuples and UDTs for non-prepared statements (PYTHON-100)
* Raise TypeError when a string is used for contact_points (github #164)
* Include User Defined Types in KeyspaceMetadata.export_as_string() (PYTHON-96)
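
As a quick sketch of the PYTHON-93 fix above (grounded in the paging tests later in this commit), per-statement paging can now be disabled even when the session enables it by default, assuming an open `session`:

# Hedged sketch: per-statement paging control (assumes an open session)
from cassandra.query import SimpleStatement

session.default_fetch_size = 10          # paging on by default for the session
paged = session.execute(SimpleStatement("SELECT * FROM test3rf.test"))
unpaged = session.execute(SimpleStatement("SELECT * FROM test3rf.test",
                                          fetch_size=None))  # paging off for this statement only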
Other
-----

View File

@@ -4,6 +4,10 @@ Releasing
* If dependencies have changed, make sure ``debian/control``
is up to date
* Make sure all patches in ``debian/patches`` still apply cleanly
* Update the debian changelog with the new version::
dch -v '1.0.0'
* Update CHANGELOG.rst
* Update the version in ``cassandra/__init__.py``
* Commit the changelog and version changes

View File

@@ -42,9 +42,9 @@ from functools import partial, wraps
from itertools import groupby
from cassandra import (ConsistencyLevel, AuthenticationFailed,
OperationTimedOut, UnsupportedOperation)
InvalidRequest, OperationTimedOut, UnsupportedOperation)
from cassandra.connection import ConnectionException, ConnectionShutdown
from cassandra.encoder import cql_encode_all_types, cql_encoders
from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage,
ErrorMessage, ReadTimeoutErrorMessage,
WriteTimeoutErrorMessage,
@@ -65,7 +65,7 @@ from cassandra.pool import (_ReconnectionHandler, _HostReconnectionHandler,
NoConnectionsAvailable)
from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
BatchStatement, bind_params, QueryTrace, Statement,
named_tuple_factory, dict_factory)
named_tuple_factory, dict_factory, FETCH_SIZE_UNSET)
# default to gevent when we are monkey patched, otherwise if libev is available, use that as the
# default because it's fastest. Otherwise, use asyncore.
@@ -380,7 +380,12 @@ class Cluster(object):
Any of the mutable Cluster attributes may be set as keyword arguments
to the constructor.
"""
self.contact_points = contact_points
if contact_points is not None:
if isinstance(contact_points, six.string_types):
raise TypeError("contact_points should not be a string, it should be a sequence (e.g. list) of strings")
self.contact_points = contact_points
self.port = port
self.compression = compression
self.protocol_version = protocol_version
@@ -467,6 +472,53 @@ class Cluster(object):
self, self.control_connection_timeout)
def register_user_type(self, keyspace, user_type, klass):
"""
Registers a class to use to represent a particular user-defined type.
Query parameters for this user-defined type will be assumed to be
instances of `klass`. Result sets for this user-defined type will
be instances of `klass`. If no class is registered for a user-defined
type, a namedtuple will be used for result sets, and non-prepared
statements may not encode parameters for this type correctly.
`keyspace` is the name of the keyspace that the UDT is defined in.
`user_type` is the string name of the UDT to register the mapping
for.
`klass` should be a class with attributes whose names match the
fields of the user-defined type. The constructor must accept kwargs
for each of the fields in the UDT.
This method should only be called after the type has been created
within Cassandra.
Example::
cluster = Cluster(protocol_version=3)
session = cluster.connect()
session.set_keyspace('mykeyspace')
session.execute("CREATE TYPE address (street text, zipcode int)")
session.execute("CREATE TABLE users (id int PRIMARY KEY, location address)")
# create a class to map to the "address" UDT
class Address(object):
def __init__(self, street, zipcode):
self.street = street
self.zipcode = zipcode
cluster.register_user_type('mykeyspace', 'address', Address)
# insert a row using an instance of Address
session.execute("INSERT INTO users (id, location) VALUES (%s, %s)",
(0, Address("123 Main St.", 78723)))
# results will include Address instances
results = session.execute("SELECT * FROM users")
row = results[0]
print row.id, row.location.street, row.location.zipcode
"""
self._user_types[keyspace][user_type] = klass
for session in self.sessions:
session.user_type_registered(keyspace, user_type, klass)
@@ -591,6 +643,8 @@ class Cluster(object):
raise Exception("Cluster is already shut down")
if not self._is_setup:
log.debug("Connecting to cluster, contact points: %s; protocol version: %s",
self.contact_points, self.protocol_version)
self.connection_class.initialize_reactor()
atexit.register(partial(_shutdown_cluster, self))
for address in self.contact_points:
@@ -1115,14 +1169,36 @@ class Session(object):
.. versionadded:: 2.1.0
"""
encoder = None
"""
A :class:`~cassandra.encoder.Encoder` instance that will be used when
formatting query parameters for non-prepared statements. This is not used
for prepared statements (because prepared statements give the driver more
information about what CQL types are expected, allowing it to accept a
wider range of python types).
The encoder uses a mapping from python types to encoder methods (for
specific CQL types). This mapping can be modified by users as they see
fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping
values if possible, because they take precautions to avoid injections and
properly sanitize data.
Example::
cluster = Cluster()
session = cluster.connect("mykeyspace")
session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple
session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple<int, ascii>)")
session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')])
"""
_lock = None
_pools = None
_load_balancer = None
_metrics = None
_protocol_version = None
encoders = None
def __init__(self, cluster, hosts):
self.cluster = cluster
self.hosts = hosts
@@ -1133,7 +1209,7 @@ class Session(object):
self._metrics = cluster.metrics
self._protocol_version = self.cluster.protocol_version
self.encoders = cql_encoders.copy()
self.encoder = Encoder()
# create connection pools in parallel
futures = []
@@ -1246,8 +1322,10 @@ class Session(object):
cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level
fetch_size = query.fetch_size
if not fetch_size and self._protocol_version >= 2:
if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2:
fetch_size = self.default_fetch_size
elif self._protocol_version == 1:
fetch_size = None
if self._protocol_version >= 3 and self.use_client_timestamp:
timestamp = int(time.time() * 1e6)
@@ -1259,7 +1337,7 @@ class Session(object):
if six.PY2 and isinstance(query_string, six.text_type):
query_string = query_string.encode('utf-8')
if parameters:
query_string = bind_params(query_string, parameters, self.encoders)
query_string = bind_params(query_string, parameters, self.encoder)
message = QueryMessage(
query_string, cl, query.serial_consistency_level,
fetch_size, timestamp=timestamp)
@@ -1504,15 +1582,25 @@ class Session(object):
mapping from a user-defined type to a class. Intended for internal
use only.
"""
type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type]
try:
ks_meta = self.cluster.metadata.keyspaces[keyspace]
except KeyError:
raise UserTypeDoesNotExist(
'Keyspace %s does not exist or has not been discovered by the driver' % (keyspace,))
try:
type_meta = ks_meta.user_types[user_type]
except KeyError:
raise UserTypeDoesNotExist(
'User type %s does not exist in keyspace %s' % (user_type, keyspace))
def encode(val):
return '{ %s }' % ' , '.join('%s : %s' % (
field_name,
cql_encode_all_types(getattr(val, field_name))
self.encoder.cql_encode_all_types(getattr(val, field_name))
) for field_name in type_meta.field_names)
self.encoders[klass] = encode
self.encoder.mapping[klass] = encode
def submit(self, fn, *args, **kwargs):
""" Internal """
@@ -1523,6 +1611,13 @@ class Session(object):
return dict((host, pool.get_state()) for host, pool in self._pools.items())
class UserTypeDoesNotExist(Exception):
"""
An attempt was made to use a user-defined type that does not exist.
"""
pass
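A hedged sketch of the new error path (grounded in the tests later in this commit; `Address` is the class from the register_user_type example above):

from cassandra.cluster import UserTypeDoesNotExist

try:
    cluster.register_user_type('some_bad_keyspace', 'address', Address)
except UserTypeDoesNotExist as exc:
    # "Keyspace some_bad_keyspace does not exist or has not been discovered by the driver"
    print(exc)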
class _ControlReconnectionHandler(_ReconnectionHandler):
"""
Internal
@@ -1807,19 +1902,37 @@ class ControlConnection(object):
queries = [
QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl),
QueryMessage(query=self._SELECT_COLUMN_FAMILIES, consistency_level=cl),
QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl)
QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl),
QueryMessage(query=self._SELECT_USERTYPES, consistency_level=cl)
]
if self._protocol_version >= 3:
queries.append(QueryMessage(query=self._SELECT_USERTYPES, consistency_level=cl))
ks_result, cf_result, col_result, types_result = connection.wait_for_responses(*queries)
responses = connection.wait_for_responses(*queries, fail_on_error=False)
(ks_success, ks_result), (cf_success, cf_result), (col_success, col_result), (types_success, types_result) = responses
if ks_success:
ks_result = dict_factory(*ks_result.results)
else:
raise ks_result
if cf_success:
cf_result = dict_factory(*cf_result.results)
else:
raise cf_result
if col_success:
col_result = dict_factory(*col_result.results)
else:
raise col_result
# if we're connected to Cassandra < 2.1, the usertypes table will not exist
if types_success:
types_result = dict_factory(*types_result.results) if types_result.results else {}
else:
ks_result, cf_result, col_result = connection.wait_for_responses(*queries)
types_result = {}
ks_result = dict_factory(*ks_result.results)
cf_result = dict_factory(*cf_result.results)
col_result = dict_factory(*col_result.results)
if isinstance(types_result, InvalidRequest):
log.debug("[control connection] user types table not found")
types_result = {}
else:
raise types_result
log.debug("[control connection] Fetched schema, rebuilding metadata")
self._cluster.metadata.rebuild_schema(ks_result, types_result, cf_result, col_result)
@@ -2558,7 +2671,7 @@ class ResponseFuture(object):
# to retry the operation
return
if reuse_connection and self._query(self._current_host):
if reuse_connection and self._query(self._current_host) is not None:
return
# otherwise, move onto another host
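This one-line change is the PYTHON-88 fix from the changelog: `_query()` returns a stream (request) ID on success and None on failure, and stream ID zero is falsy, so the old truthiness test misread a successful send on stream 0 as a failure and retried on another host. A minimal illustration of the trap:

request_id = 0               # a valid stream id returned by _query()
if request_id:               # wrong: 0 is falsy, looks like a failure
    pass
if request_id is not None:   # right: only None signals failure
    pass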

View File

@@ -302,10 +302,19 @@ class Connection(object):
return self.wait_for_responses(msg, timeout=timeout)[0]
def wait_for_responses(self, *msgs, **kwargs):
"""
Returns a list of (success, response) tuples. If success
is False, response will be an Exception. Otherwise, response
will be the normal query response.
If fail_on_error was left as True and one of the requests
failed, the corresponding Exception will be raised.
"""
if self.is_closed or self.is_defunct:
raise ConnectionShutdown("Connection %s is already closed" % (self, ))
timeout = kwargs.get('timeout')
waiter = ResponseWaiter(self, len(msgs))
fail_on_error = kwargs.get('fail_on_error', True)
waiter = ResponseWaiter(self, len(msgs), fail_on_error)
# busy wait for sufficient space on the connection
messages_sent = 0
@@ -677,9 +686,10 @@ class Connection(object):
class ResponseWaiter(object):
def __init__(self, connection, num_responses):
def __init__(self, connection, num_responses, fail_on_error):
self.connection = connection
self.pending = num_responses
self.fail_on_error = fail_on_error
self.error = None
self.responses = [None] * num_responses
self.event = Event()
@@ -688,15 +698,33 @@ class ResponseWaiter(object):
with self.connection.lock:
self.connection.in_flight -= 1
if isinstance(response, Exception):
self.error = response
self.event.set()
else:
self.responses[index] = response
self.pending -= 1
if not self.pending:
if hasattr(response, 'to_exception'):
response = response.to_exception()
if self.fail_on_error:
self.error = response
self.event.set()
else:
self.responses[index] = (False, response)
else:
if not self.fail_on_error:
self.responses[index] = (True, response)
else:
self.responses[index] = response
self.pending -= 1
if not self.pending:
self.event.set()
def deliver(self, timeout=None):
"""
If fail_on_error was set to False, a list of (success, response)
tuples will be returned. If success is False, response will be
an Exception. Otherwise, response will be the normal query response.
If fail_on_error was left as True and one of the requests
failed, the corresponding Exception will be raised. Otherwise,
the normal response will be returned.
"""
self.event.wait(timeout)
if self.error:
raise self.error
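
A hedged usage sketch of the new flag (the connection, the msg1/msg2 request messages, and the handle() helper are assumed for illustration):

# connection is an open Connection; msg1/msg2 are request messages
responses = connection.wait_for_responses(msg1, msg2, fail_on_error=False)
for success, response in responses:
    if success:
        handle(response)                            # normal query response
    else:
        log.debug("request failed: %s", response)   # response is an Exception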

View File

@@ -806,6 +806,10 @@ class TupleType(_ParameterizedType):
@classmethod
def serialize_safe(cls, val, protocol_version):
if len(val) > len(cls.subtypes):
raise ValueError("Expected %d items in a tuple, but got %d: %s" %
(len(cls.subtypes), len(val), val))
proto_version = max(3, protocol_version)
buf = io.BytesIO()
for item, subtype in zip(val, cls.subtypes):
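
This is the PYTHON-98 guard from the changelog: binding more items than the tuple type declares now raises ValueError instead of mis-serializing. A hedged sketch, assuming a table whose column b is tuple<int, int, int>:

prepared = session.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
session.execute(prepared, (0, (1, 2, 3)))      # fits: three declared fields
session.execute(prepared, (0, (1, 2, 3, 4)))   # raises ValueError: extra items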

View File

@@ -11,6 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
These functions are used to convert Python objects into CQL strings.
When non-prepared statements are executed, these encoder functions are
called on each query parameter.
"""
import logging
log = logging.getLogger(__name__)
@@ -44,104 +49,156 @@ def cql_quote(term):
return str(term)
def cql_encode_none(val):
return 'NULL'
def cql_encode_unicode(val):
return cql_quote(val.encode('utf-8'))
def cql_encode_str(val):
return cql_quote(val)
if six.PY3:
def cql_encode_bytes(val):
return (b'0x' + hexlify(val)).decode('utf-8')
elif sys.version_info >= (2, 7):
def cql_encode_bytes(val): # noqa
return b'0x' + hexlify(val)
else:
# python 2.6 requires string or read-only buffer for hexlify
def cql_encode_bytes(val): # noqa
return b'0x' + hexlify(buffer(val))
def cql_encode_object(val):
return str(val)
def cql_encode_datetime(val):
timestamp = calendar.timegm(val.utctimetuple())
return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
def cql_encode_date(val):
return "'%s'" % val.strftime('%Y-%m-%d-0000')
def cql_encode_sequence(val):
return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v)
for v in val)
cql_encode_tuple = cql_encode_sequence
def cql_encode_map_collection(val):
return '{ %s }' % ' , '.join('%s : %s' % (
cql_encode_all_types(k),
cql_encode_all_types(v)
) for k, v in six.iteritems(val))
def cql_encode_list_collection(val):
return '[ %s ]' % ' , '.join(map(cql_encode_all_types, val))
def cql_encode_set_collection(val):
return '{ %s }' % ' , '.join(map(cql_encode_all_types, val))
def cql_encode_all_types(val):
return cql_encoders.get(type(val), cql_encode_object)(val)
cql_encoders = {
float: cql_encode_object,
bytearray: cql_encode_bytes,
str: cql_encode_str,
int: cql_encode_object,
UUID: cql_encode_object,
datetime.datetime: cql_encode_datetime,
datetime.date: cql_encode_date,
dict: cql_encode_map_collection,
OrderedDict: cql_encode_map_collection,
list: cql_encode_list_collection,
tuple: cql_encode_list_collection,
set: cql_encode_set_collection,
frozenset: cql_encode_set_collection,
types.GeneratorType: cql_encode_list_collection
}
if six.PY2:
cql_encoders.update({
unicode: cql_encode_unicode,
buffer: cql_encode_bytes,
long: cql_encode_object,
types.NoneType: cql_encode_none,
})
else:
cql_encoders.update({
memoryview: cql_encode_bytes,
bytes: cql_encode_bytes,
type(None): cql_encode_none,
})
# sortedset is optional
try:
from blist import sortedset
cql_encoders.update({
sortedset: cql_encode_set_collection
})
except ImportError:
pass
class ValueSequence(list):
pass
class Encoder(object):
"""
A container for mapping python types to CQL string literals when working
with non-prepared statements. The type :attr:`~.Encoder.mapping` can be
directly customized by users.
"""
mapping = None
"""
A map of python types to encoder functions.
"""
def __init__(self):
self.mapping = {
float: self.cql_encode_object,
bytearray: self.cql_encode_bytes,
str: self.cql_encode_str,
int: self.cql_encode_object,
UUID: self.cql_encode_object,
datetime.datetime: self.cql_encode_datetime,
datetime.date: self.cql_encode_date,
dict: self.cql_encode_map_collection,
OrderedDict: self.cql_encode_map_collection,
list: self.cql_encode_list_collection,
tuple: self.cql_encode_list_collection,
set: self.cql_encode_set_collection,
frozenset: self.cql_encode_set_collection,
types.GeneratorType: self.cql_encode_list_collection,
ValueSequence: self.cql_encode_sequence
}
if six.PY2:
self.mapping.update({
unicode: self.cql_encode_unicode,
buffer: self.cql_encode_bytes,
long: self.cql_encode_object,
types.NoneType: self.cql_encode_none,
})
else:
self.mapping.update({
memoryview: self.cql_encode_bytes,
bytes: self.cql_encode_bytes,
type(None): self.cql_encode_none,
})
# sortedset is optional
try:
from blist import sortedset
self.mapping.update({
sortedset: self.cql_encode_set_collection
})
except ImportError:
pass
def cql_encode_none(self, val):
"""
Converts :const:`None` to the string 'NULL'.
"""
return 'NULL'
def cql_encode_unicode(self, val):
"""
Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping.
"""
return cql_quote(val.encode('utf-8'))
def cql_encode_str(self, val):
"""
Escapes quotes in :class:`str` objects.
"""
return cql_quote(val)
if six.PY3:
def cql_encode_bytes(self, val):
return (b'0x' + hexlify(val)).decode('utf-8')
elif sys.version_info >= (2, 7):
def cql_encode_bytes(self, val): # noqa
return b'0x' + hexlify(val)
else:
# python 2.6 requires string or read-only buffer for hexlify
def cql_encode_bytes(self, val): # noqa
return b'0x' + hexlify(buffer(val))
def cql_encode_object(self, val):
"""
Default encoder for all objects that do not have a specific encoder function
registered. This function simply calls :meth:`str()` on the object.
"""
return str(val)
def cql_encode_datetime(self, val):
"""
Converts a :class:`datetime.datetime` object to a (string) integer timestamp
with millisecond precision.
"""
timestamp = calendar.timegm(val.utctimetuple())
return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
def cql_encode_date(self, val):
"""
Converts a :class:`datetime.date` object to a string with format
``YYYY-MM-DD-0000``.
"""
return "'%s'" % val.strftime('%Y-%m-%d-0000')
def cql_encode_sequence(self, val):
"""
Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``IN`` value lists.
"""
return '( %s )' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v)
for v in val)
cql_encode_tuple = cql_encode_sequence
"""
Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``tuple`` type columns.
"""
def cql_encode_map_collection(self, val):
"""
Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``.
This is suitable for ``map`` type columns.
"""
return '{ %s }' % ' , '.join('%s : %s' % (
self.mapping.get(type(k), self.cql_encode_object)(k),
self.mapping.get(type(v), self.cql_encode_object)(v)
) for k, v in six.iteritems(val))
def cql_encode_list_collection(self, val):
"""
Converts a sequence to a string of the form ``[item1, item2, ...]``. This
is suitable for ``list`` type columns.
"""
return '[ %s ]' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val)
def cql_encode_set_collection(self, val):
"""
Converts a sequence to a string of the form ``{item1, item2, ...}``. This
is suitable for ``set`` type columns.
"""
return '{ %s }' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val)
def cql_encode_all_types(self, val):
"""
Converts any type into a CQL string, defaulting to ``cql_encode_object``
if :attr:`~Encoder.mapping` does not contain an entry for the type.
"""
return self.mapping.get(type(val), self.cql_encode_object)(val)
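
A short usage sketch of the class above, mirroring the session-level example earlier in this diff:

from cassandra.encoder import Encoder

encoder = Encoder()
encoder.mapping[tuple] = encoder.cql_encode_tuple   # opt in to tuple literals
print(encoder.cql_encode_all_types((1, 'abc')))     # ( 1 , 'abc' )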

View File

@@ -450,7 +450,7 @@ class SimpleStrategy(ReplicationStrategy):
while len(hosts) < self.replication_factor and j < len(ring):
token = ring[(i + j) % len(ring)]
host = token_to_host_owner[token]
if not host in hosts:
if host not in hosts:
hosts.append(host)
j += 1
@@ -597,7 +597,7 @@ class KeyspaceMetadata(object):
self.user_types = {}
def export_as_string(self):
return "\n".join([self.as_cql_query()] + [t.export_as_string() for t in self.tables.values()])
return "\n\n".join([self.as_cql_query()] + self.user_type_strings() + [t.export_as_string() for t in self.tables.values()])
def as_cql_query(self):
ret = "CREATE KEYSPACE %s WITH replication = %s " % (
@@ -605,13 +605,51 @@ class KeyspaceMetadata(object):
self.replication_strategy.export_for_schema())
return ret + (' AND durable_writes = %s;' % ("true" if self.durable_writes else "false"))
def user_type_strings(self):
user_type_strings = []
types = self.user_types.copy()
keys = sorted(types.keys())
for k in keys:
if k in types:
self.resolve_user_types(k, types, user_type_strings)
return user_type_strings
def resolve_user_types(self, key, types, user_type_strings):
user_type = types.pop(key)
for field_type in user_type.field_types:
if field_type.cassname == 'UserType' and field_type.typename in types:
self.resolve_user_types(field_type.typename, types, user_type_strings)
user_type_strings.append(user_type.as_cql_query(formatted=True))
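The effect of the resolution above: a UDT is always emitted after every UDT it nests, so the exported schema replays cleanly against a fresh cluster. A hedged sketch (keyspace name assumed):

# If type `user` nests type `address`, the output contains
# CREATE TYPE address (...) before CREATE TYPE user (...).
print(cluster.metadata.keyspaces['mykeyspace'].export_as_string())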
class UserType(object):
"""
A user defined type, as created by ``CREATE TYPE`` statements.
User-defined types were introduced in Cassandra 2.1.
.. versionadded:: 2.1.0
"""
keyspace = None
"""
The string name of the keyspace in which this type is defined.
"""
name = None
"""
The name of this type.
"""
field_names = None
"""
An ordered list of the names for each field in this user-defined type.
"""
field_types = None
"""
An ordered list of the types for each field in this user-defined type.
"""
def __init__(self, keyspace, name, field_names, field_types):
self.keyspace = keyspace
@@ -619,6 +657,32 @@ class UserType(object):
self.field_names = field_names
self.field_types = field_types
def as_cql_query(self, formatted=False):
"""
Returns a CQL query that can be used to recreate this type.
If `formatted` is set to :const:`True`, extra whitespace will
be added to make the query more readable.
"""
ret = "CREATE TYPE %s.%s (%s" % (
protect_name(self.keyspace),
protect_name(self.name),
"\n" if formatted else "")
if formatted:
field_join = ",\n"
padding = " "
else:
field_join = ", "
padding = ""
fields = []
for field_name, field_type in zip(self.field_names, self.field_types):
fields.append("%s %s" % (protect_name(field_name), field_type.cql_parameterized_type()))
ret += field_join.join("%s%s" % (padding, field) for field in fields)
ret += "\n);" if formatted else ");"
return ret
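For reference, the formatted output for a hypothetical two-field type looks roughly like:

CREATE TYPE mykeyspace.address (
    street text,
    zipcode int
);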
class TableMetadata(object):
"""

View File

@@ -17,7 +17,6 @@ Connection pooling and host management.
"""
import logging
import re
import socket
import time
from threading import Lock, RLock, Condition
@@ -42,13 +41,6 @@ class NoConnectionsAvailable(Exception):
pass
# example matches:
# 1.0.0
# 1.0.0-beta1
# 2.0-SNAPSHOT
version_re = re.compile(r"(?P<major>\d+)\.(?P<minor>\d+)(?:\.(?P<patch>\d+))?(?:-(?P<label>\w+))?")
class Host(object):
"""
Represents a single Cassandra node.
@@ -72,12 +64,6 @@ class Host(object):
up or down.
"""
version = None
"""
A tuple representing the Cassandra version for this host. This will
remain as :const:`None` if the version is unknown.
"""
_datacenter = None
_rack = None
_reconnection_handler = None
@@ -115,14 +101,6 @@ class Host(object):
self._datacenter = datacenter
self._rack = rack
def set_version(self, version_string):
match = version_re.match(version_string)
if match is not None:
version = [int(match.group('major')), int(match.group('minor')), int(match.group('patch') or 0)]
if match.group('label'):
version.append(match.group('label'))
self.version = tuple(version)
def set_up(self):
if not self.is_up:
log.debug("Host %s is now marked up", self.address)

View File

@@ -27,8 +27,8 @@ import six
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cqltypes import unix_time_from_uuid1
from cassandra.encoder import (cql_encoders, cql_encode_object,
cql_encode_sequence)
from cassandra.encoder import Encoder
import cassandra.encoder
from cassandra.util import OrderedDict
import logging
@@ -57,7 +57,7 @@ def tuple_factory(colnames, rows):
Example::
>>> from cassandra.query import named_tuple_factory
>>> from cassandra.query import tuple_factory
>>> session = cluster.connect('mykeyspace')
>>> session.row_factory = tuple_factory
>>> rows = session.execute("SELECT name, age FROM users LIMIT 1")
@@ -133,6 +133,9 @@ def ordered_dict_factory(colnames, rows):
return [OrderedDict(zip(colnames, row)) for row in rows]
FETCH_SIZE_UNSET = object()
class Statement(object):
"""
An abstract class representing a single query. There are three subclasses:
@@ -160,7 +163,7 @@ class Statement(object):
the Session this is executed in will be used.
"""
fetch_size = None
fetch_size = FETCH_SIZE_UNSET
"""
How many rows will be fetched at a time. This overrides the default
of :attr:`.Session.default_fetch_size`
@@ -175,14 +178,14 @@ class Statement(object):
_routing_key = None
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
serial_consistency_level=None, fetch_size=None):
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET):
self.retry_policy = retry_policy
if consistency_level is not None:
self.consistency_level = consistency_level
if serial_consistency_level is not None:
self.serial_consistency_level = serial_consistency_level
if fetch_size is not None:
self.fetch_size = fetch_size
if fetch_size is not FETCH_SIZE_UNSET:
self.fetch_size = fetch_size
self._routing_key = routing_key
def _get_routing_key(self):
@@ -315,9 +318,11 @@ class PreparedStatement(object):
_protocol_version = None
fetch_size = FETCH_SIZE_UNSET
def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace,
protocol_version, consistency_level=None, serial_consistency_level=None,
fetch_size=None):
fetch_size=FETCH_SIZE_UNSET):
self.column_metadata = column_metadata
self.query_id = query_id
self.routing_key_indexes = routing_key_indexes
@@ -326,7 +331,8 @@ class PreparedStatement(object):
self._protocol_version = protocol_version
self.consistency_level = consistency_level
self.serial_consistency_level = serial_consistency_level
self.fetch_size = fetch_size
if fetch_size is not FETCH_SIZE_UNSET:
self.fetch_size = fetch_size
@classmethod
def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace, protocol_version):
@@ -619,8 +625,8 @@ class BatchStatement(Statement):
"""
if isinstance(statement, six.string_types):
if parameters:
encoders = cql_encoders if self._session is None else self._session.encoders
statement = bind_params(statement, parameters, encoders)
encoder = Encoder() if self._session is None else self._session.encoder
statement = bind_params(statement, parameters, encoder)
self._statements_and_parameters.append((False, statement, ()))
elif isinstance(statement, PreparedStatement):
query_id = statement.query_id
@@ -638,8 +644,8 @@ class BatchStatement(Statement):
# it must be a SimpleStatement
query_string = statement.query_string
if parameters:
encoders = cql_encoders if self._session is None else self._session.encoders
query_string = bind_params(query_string, parameters, encoders)
encoder = Encoder() if self._session is None else self._session.encoder
query_string = bind_params(query_string, parameters, encoder)
self._statements_and_parameters.append((False, query_string, ()))
return self
@@ -659,33 +665,27 @@ class BatchStatement(Statement):
__repr__ = __str__
class ValueSequence(object):
"""
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
ValueSequence = cassandra.encoder.ValueSequence
"""
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
This is typically needed when supplying a list of keys to select.
For example::
This is typically needed when supplying a list of keys to select.
For example::
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
"""
def __init__(self, sequence):
self.sequence = sequence
def __str__(self):
return cql_encode_sequence(self.sequence)
"""
def bind_params(query, params, encoders):
def bind_params(query, params, encoder):
if isinstance(params, dict):
return query % dict((k, encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params))
else:
return query % tuple(encoders.get(type(v), cql_encode_object)(v) for v in params)
return query % tuple(encoder.cql_encode_all_types(v) for v in params)
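
A hedged sketch of the updated signature, using the Encoder from this diff:

from cassandra.encoder import Encoder
from cassandra.query import bind_params

bind_params("SELECT * FROM users WHERE k=%s AND name=%s", (5, "O'Reilly"), Encoder())
# -> "SELECT * FROM users WHERE k=5 AND name='O''Reilly'"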
class TraceUnavailable(Exception):

View File

@@ -43,6 +43,8 @@
.. automethod:: shutdown
.. automethod:: register_user_type
.. automethod:: register_listener
.. automethod:: unregister_listener
@@ -63,6 +65,10 @@
.. autoattribute:: default_fetch_size
.. autoattribute:: use_client_timestamp
.. autoattribute:: encoder
.. automethod:: execute(statement[, parameters][, timeout][, trace])
.. automethod:: execute_async(statement[, parameters][, trace])
@@ -98,3 +104,5 @@
.. autoexception:: NoHostAvailable ()
:members:
.. autoexception:: UserTypeDoesNotExist ()

View File

@@ -0,0 +1,36 @@
``cassandra.encoder`` - Encoders for non-prepared Statements
============================================================
.. module:: cassandra.encoder
.. autoclass:: Encoder ()
.. autoattribute:: cassandra.encoder.Encoder.mapping
.. automethod:: cassandra.encoder.Encoder.cql_encode_none ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_object ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_all_types ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_sequence ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_str ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_unicode ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_bytes ()
Converts strings, buffers, and bytearrays into CQL blob literals.
.. automethod:: cassandra.encoder.Encoder.cql_encode_datetime ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_date ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_map_collection ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_list_collection ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_set_collection ()
.. automethod:: cassandra.encoder.Encoder.cql_encode_tuple ()

View File

@@ -2,7 +2,6 @@
===========================================
.. module:: cassandra.metrics
:members:
.. autoclass:: cassandra.metrics.Metrics ()
:members:

View File

@@ -34,8 +34,18 @@
.. autoattribute:: COUNTER
.. autoclass:: ValueSequence
:members:
.. autoclass:: cassandra.query.ValueSequence
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
This is typically needed when supplying a list of keys to select.
For example::
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
.. autoclass:: QueryTrace ()
:members:

View File

@@ -12,6 +12,7 @@ API Documentation
cassandra/metrics
cassandra/query
cassandra/pool
cassandra/encoder
cassandra/decoder
cassandra/concurrent
cassandra/connection

View File

@@ -137,7 +137,7 @@ when you execute:
"""
INSERT INTO users (name, credits, user_id)
VALUES (%s, %s, %s)
"""
""",
("John O'Reilly", 42, uuid.uuid1())
)

View File

@@ -40,6 +40,7 @@ MULTIDC_CLUSTER_NAME = 'multidc_test_cluster'
CCM_CLUSTER = None
CASSANDRA_DIR = os.getenv('CASSANDRA_DIR', None)
CASSANDRA_HOME = os.getenv('CASSANDRA_HOME', None)
CASSANDRA_VERSION = os.getenv('CASSANDRA_VERSION', '2.0.9')
if CASSANDRA_VERSION.startswith('1'):
@@ -97,6 +98,8 @@ def get_node(node_id):
def setup_package():
if CASSANDRA_DIR:
log.info("Using Cassandra dir: %s", CASSANDRA_DIR)
elif CASSANDRA_HOME:
log.info("Using Cassandra home: %s", CASSANDRA_HOME)
else:
log.info('Using Cassandra version: %s', CASSANDRA_VERSION)
try:

View File

@@ -0,0 +1,139 @@
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from decimal import Decimal
import datetime
from uuid import UUID
import pytz
try:
from blist import sortedset
except ImportError:
sortedset = set # noqa
DATA_TYPE_PRIMITIVES = [
'ascii',
'bigint',
'blob',
'boolean',
# 'counter', counters are not allowed inside tuples
'decimal',
'double',
'float',
'inet',
'int',
'text',
'timestamp',
'timeuuid',
'uuid',
'varchar',
'varint',
]
DATA_TYPE_NON_PRIMITIVE_NAMES = [
'list',
'set',
'map',
'tuple'
]
def get_sample_data():
"""
Create a standard set of sample inputs for testing.
"""
sample_data = {}
for datatype in DATA_TYPE_PRIMITIVES:
if datatype == 'ascii':
sample_data[datatype] = 'ascii'
elif datatype == 'bigint':
sample_data[datatype] = 2 ** 63 - 1
elif datatype == 'blob':
sample_data[datatype] = bytearray(b'hello world')
elif datatype == 'boolean':
sample_data[datatype] = True
elif datatype == 'counter':
# Not supported in an insert statement
pass
elif datatype == 'decimal':
sample_data[datatype] = Decimal('12.3E+7')
elif datatype == 'double':
sample_data[datatype] = 1.23E+8
elif datatype == 'float':
sample_data[datatype] = 3.4028234663852886e+38
elif datatype == 'inet':
sample_data[datatype] = '123.123.123.123'
elif datatype == 'int':
sample_data[datatype] = 2147483647
elif datatype == 'text':
sample_data[datatype] = 'text'
elif datatype == 'timestamp':
sample_data[datatype] = datetime.datetime.fromtimestamp(872835240, tz=pytz.timezone('America/New_York')).astimezone(pytz.UTC).replace(tzinfo=None)
elif datatype == 'timeuuid':
sample_data[datatype] = UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66')
elif datatype == 'uuid':
sample_data[datatype] = UUID('067e6162-3b6f-4ae2-a171-2470b63dff00')
elif datatype == 'varchar':
sample_data[datatype] = 'varchar'
elif datatype == 'varint':
sample_data[datatype] = int(str(2147483647) + '000')
else:
raise Exception('Missing handling of %s.' % datatype)
return sample_data
SAMPLE_DATA = get_sample_data()
def get_sample(datatype):
"""
Helper method to access created sample data
"""
return SAMPLE_DATA[datatype]
def get_nonprim_sample(non_prim_type, datatype):
"""
Helper method to access created sample data for non-primitives
"""
if non_prim_type == 'list':
return [get_sample(datatype), get_sample(datatype)]
elif non_prim_type == 'set':
return sortedset([get_sample(datatype)])
elif non_prim_type == 'map':
if datatype == 'blob':
return {get_sample('ascii'): get_sample(datatype)}
else:
return {get_sample(datatype): get_sample(datatype)}
elif non_prim_type == 'tuple':
return (get_sample(datatype),)
else:
raise Exception('Missing handling of non-primitive type {0}.'.format(non_prim_type))
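Usage sketch for the helpers above:

get_sample('bigint')               # -> 9223372036854775807 (2**63 - 1)
get_nonprim_sample('list', 'int')  # -> [2147483647, 2147483647]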

View File

@@ -18,9 +18,8 @@ except ImportError:
import unittest # noqa
from cassandra import ConsistencyLevel
from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
SimpleStatement, BatchStatement, BatchType,
dict_factory)
from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement,
BatchStatement, BatchType, dict_factory)
from cassandra.cluster import Cluster
from cassandra.policies import HostDistance
@@ -45,14 +44,6 @@ class QueryTest(unittest.TestCase):
session.execute(bound)
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
def test_value_sequence(self):
"""
Test the output of ValueSequences()
"""
my_user_ids = ('alice', 'bob', 'charles')
self.assertEqual(str(ValueSequence(my_user_ids)), "( 'alice' , 'bob' , 'charles' )")
def test_trace_prints_okay(self):
"""
Code coverage to ensure trace prints to string without error

View File

@@ -26,7 +26,7 @@ from itertools import cycle, count
from six.moves import range
from threading import Event
from cassandra.cluster import Cluster
from cassandra.cluster import Cluster, PagedResult
from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args
from cassandra.policies import HostDistance
from cassandra.query import SimpleStatement
@@ -281,3 +281,79 @@ class QueryPagingTests(unittest.TestCase):
for (success, result) in results:
self.assertTrue(success)
self.assertEquals(100, len(list(result)))
def test_fetch_size(self):
"""
Ensure per-statement fetch_sizes override the default fetch size.
"""
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)])
execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test")
self.session.default_fetch_size = 10
result = self.session.execute(prepared, [])
self.assertIsInstance(result, PagedResult)
self.session.default_fetch_size = 2000
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
self.session.default_fetch_size = None
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
self.session.default_fetch_size = 10
prepared.fetch_size = 2000
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
prepared.fetch_size = None
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
prepared.fetch_size = 10
result = self.session.execute(prepared, [])
self.assertIsInstance(result, PagedResult)
prepared.fetch_size = 2000
bound = prepared.bind([])
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
prepared.fetch_size = None
bound = prepared.bind([])
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
prepared.fetch_size = 10
bound = prepared.bind([])
result = self.session.execute(bound, [])
self.assertIsInstance(result, PagedResult)
bound.fetch_size = 2000
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
bound.fetch_size = None
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
bound.fetch_size = 10
result = self.session.execute(bound, [])
self.assertIsInstance(result, PagedResult)
s = SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None)
result = self.session.execute(s, [])
self.assertIsInstance(result, list)
s = SimpleStatement("SELECT * FROM test3rf.test")
result = self.session.execute(s, [])
self.assertIsInstance(result, PagedResult)
s = SimpleStatement("SELECT * FROM test3rf.test")
s.fetch_size = None
result = self.session.execute(s, [])
self.assertIsInstance(result, list)

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.integration.datatype_utils import get_sample, DATA_TYPE_PRIMITIVES, DATA_TYPE_NON_PRIMITIVE_NAMES
try:
import unittest2 as unittest
@@ -33,7 +34,6 @@ except ImportError:
from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.encoder import cql_encode_tuple
from cassandra.query import dict_factory
from cassandra.util import OrderedDict
@@ -87,8 +87,8 @@ class TypeTests(unittest.TestCase):
s.execute(query, params)
expected_vals = [
'key1',
bytearray(b'blobyblob')
'key1',
bytearray(b'blobyblob')
]
results = s.execute("SELECT * FROM mytable")
@@ -404,12 +404,18 @@ class TypeTests(unittest.TestCase):
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
def test_tuple_type(self):
"""
Basic test of tuple functionality
"""
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.encoders[tuple] = cql_encode_tuple
# use this encoder in order to insert tuples
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_type
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
@@ -428,14 +434,241 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT b FROM mytable WHERE a=1")[0]
self.assertEqual(partial_result, result.b)
# test single value tuples
subpartial = ('zoo',)
subpartial_result = subpartial + (None, None)
s.execute("INSERT INTO mytable (a, b) VALUES (2, %s)", parameters=(subpartial,))
result = s.execute("SELECT b FROM mytable WHERE a=2")[0]
self.assertEqual(subpartial_result, result.b)
# test prepared statement
prepared = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
s.execute(prepared, parameters=(2, complete))
s.execute(prepared, parameters=(3, partial))
s.execute(prepared, parameters=(3, complete))
s.execute(prepared, parameters=(4, partial))
s.execute(prepared, parameters=(5, subpartial))
# extra items in the tuple should result in an error
self.assertRaises(ValueError, s.execute, prepared, parameters=(0, (1, 2, 3, 4, 5, 6)))
prepared = s.prepare("SELECT b FROM mytable WHERE a=?")
self.assertEqual(complete, s.execute(prepared, (2,))[0].b)
self.assertEqual(partial_result, s.execute(prepared, (3,))[0].b)
self.assertEqual(complete, s.execute(prepared, (3,))[0].b)
self.assertEqual(partial_result, s.execute(prepared, (4,))[0].b)
self.assertEqual(subpartial_result, s.execute(prepared, (5,))[0].b)
def test_tuple_type_varying_lengths(self):
"""
Test tuple types of lengths 1, 2, 3, and 384 to ensure edge cases work
as expected.
"""
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
# set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples
s.row_factory = dict_factory
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_type_varying_lengths
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_tuple_type_varying_lengths")
# programmatically create the table with tuples of said sizes
lengths = (1, 2, 3, 384)
value_schema = []
for i in lengths:
value_schema += [' v_%s tuple<%s>' % (i, ', '.join(['int'] * i))]
s.execute("CREATE TABLE mytable (k int PRIMARY KEY, %s)" % (', '.join(value_schema),))
# insert tuples into same key using different columns
# and verify the results
for i in lengths:
created_tuple = tuple(range(0, i))
s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))
result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
self.assertEqual(tuple(created_tuple), result['v_%s' % i])
def test_tuple_primitive_subtypes(self):
"""
Ensure tuple subtypes are appropriately handled.
"""
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_primitive_subtypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_tuple_primitive_subtypes")
s.execute("CREATE TABLE mytable ("
"k int PRIMARY KEY, "
"v tuple<%s>)" % ','.join(DATA_TYPE_PRIMITIVES))
for i in range(len(DATA_TYPE_PRIMITIVES)):
# create tuples to be written and ensure they match with the expected response
# responses have trailing None values for every element that has not been written
created_tuple = [get_sample(DATA_TYPE_PRIMITIVES[j]) for j in range(i + 1)]
response_tuple = tuple(created_tuple + [None for j in range(len(DATA_TYPE_PRIMITIVES) - i - 1)])
written_tuple = tuple(created_tuple)
s.execute("INSERT INTO mytable (k, v) VALUES (%s, %s)", (i, written_tuple))
result = s.execute("SELECT v FROM mytable WHERE k=%s", (i,))[0]
self.assertEqual(response_tuple, result.v)
def test_tuple_non_primitive_subtypes(self):
"""
Ensure tuple subtypes are appropriately handled for maps, sets, and lists.
"""
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
# set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples
s.row_factory = dict_factory
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_non_primitive_subtypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_tuple_non_primitive_subtypes")
values = []
# create list values
for datatype in DATA_TYPE_PRIMITIVES:
values.append('v_{} tuple<list<{}>>'.format(len(values), datatype))
# create set values
for datatype in DATA_TYPE_PRIMITIVES:
values.append('v_{} tuple<set<{}>>'.format(len(values), datatype))
# create map values
for datatype in DATA_TYPE_PRIMITIVES:
datatype_1 = datatype_2 = datatype
if datatype == 'blob':
# unhashable type: 'bytearray'
datatype_1 = 'ascii'
values.append('v_{} tuple<map<{}, {}>>'.format(len(values), datatype_1, datatype_2))
# make sure we're testing all non primitive data types in the future
if set(DATA_TYPE_NON_PRIMITIVE_NAMES) != set(['tuple', 'list', 'map', 'set']):
raise NotImplementedError('Missing datatype not implemented: {}'.format(
set(DATA_TYPE_NON_PRIMITIVE_NAMES) - set(['tuple', 'list', 'map', 'set'])
))
# create table
s.execute("CREATE TABLE mytable ("
"k int PRIMARY KEY, "
"%s)" % ', '.join(values))
i = 0
# test tuple<list<datatype>>
for datatype in DATA_TYPE_PRIMITIVES:
created_tuple = tuple([[get_sample(datatype)]])
s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))
result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
self.assertEqual(created_tuple, result['v_%s' % i])
i += 1
# test tuple<set<datatype>>
for datatype in DATA_TYPE_PRIMITIVES:
created_tuple = tuple([sortedset([get_sample(datatype)])])
s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))
result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
self.assertEqual(created_tuple, result['v_%s' % i])
i += 1
# test tuple<map<datatype, datatype>>
for datatype in DATA_TYPE_PRIMITIVES:
if datatype == 'blob':
# unhashable type: 'bytearray'
created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}])
else:
created_tuple = tuple([{get_sample(datatype): get_sample(datatype)}])
s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))
result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
self.assertEqual(created_tuple, result['v_%s' % i])
i += 1
def nested_tuples_schema_helper(self, depth):
"""
Helper method for creating nested tuple schema
"""
if depth == 0:
return 'int'
else:
return 'tuple<%s>' % self.nested_tuples_schema_helper(depth - 1)
def nested_tuples_creator_helper(self, depth):
"""
Helper method for creating nested tuples
"""
if depth == 0:
return 303
else:
return (self.nested_tuples_creator_helper(depth - 1), )
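# Examples (sketch): nested_tuples_schema_helper(2)  -> 'tuple<tuple<int>>'
#                    nested_tuples_creator_helper(2) -> ((303,),)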
def test_nested_tuples(self):
"""
Ensure nested tuples are appropriately handled.
"""
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
# set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples
s.row_factory = dict_factory
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""CREATE KEYSPACE test_nested_tuples
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_nested_tuples")
# create a table with multiple sizes of nested tuples
s.execute("CREATE TABLE mytable ("
"k int PRIMARY KEY, "
"v_1 %s,"
"v_2 %s,"
"v_3 %s,"
"v_128 %s"
")" % (self.nested_tuples_schema_helper(1),
self.nested_tuples_schema_helper(2),
self.nested_tuples_schema_helper(3),
self.nested_tuples_schema_helper(128)))
for i in (1, 2, 3, 128):
# create tuple
created_tuple = self.nested_tuples_creator_helper(i)
# write tuple
s.execute("INSERT INTO mytable (k, v_%s) VALUES (%s, %s)", (i, i, created_tuple))
# verify tuple was written and read correctly
result = s.execute("SELECT v_%s FROM mytable WHERE k=%s", (i, i))[0]
self.assertEqual(created_tuple, result['v_%s' % i])
def test_unicode_query_string(self):
c = Cluster(protocol_version=PROTOCOL_VERSION)

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cassandra.query import dict_factory
try:
import unittest2 as unittest
@@ -22,9 +23,11 @@ log = logging.getLogger(__name__)
from collections import namedtuple
from cassandra.cluster import Cluster
from cassandra.cluster import Cluster, UserTypeDoesNotExist
from tests.integration import get_server_versions, PROTOCOL_VERSION
from tests.integration.datatype_utils import get_sample, get_nonprim_sample,\
DATA_TYPE_PRIMITIVES, DATA_TYPE_NON_PRIMITIVE_NAMES
class TypeTests(unittest.TestCase):
@@ -227,3 +230,257 @@ class TypeTests(unittest.TestCase):
self.assertTrue(type(row.b) is User)
c.shutdown()
def test_udt_sizes(self):
"""
Test for ensuring extra-lengthy udts are handled correctly.
"""
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("User-defined types were introduced in Cassandra 2.1")
MAX_TEST_LENGTH = 16384
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""CREATE KEYSPACE test_udt_sizes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_udt_sizes")
# create the seed udt
s.execute("CREATE TYPE lengthy_udt ({})".format(', '.join(['v_{} int'.format(i) for i in range(MAX_TEST_LENGTH)])))
# create a table with multiple sizes of nested udts
# no need for all nested types, only a spot checked few and the largest one
s.execute("CREATE TABLE mytable ("
"k int PRIMARY KEY, "
"v lengthy_udt)")
# create and register the seed udt type
udt = namedtuple('lengthy_udt', tuple(['v_{}'.format(i) for i in range(MAX_TEST_LENGTH)]))
c.register_user_type("test_udt_sizes", "lengthy_udt", udt)
# verify inserts and reads
for i in (0, 1, 2, 3, MAX_TEST_LENGTH):
# create udt
params = [j for j in range(i)] + [None for j in range(MAX_TEST_LENGTH - i)]
created_udt = udt(*params)
# write udt
s.execute("INSERT INTO mytable (k, v) VALUES (0, %s)", (created_udt,))
# verify udt was written and read correctly
result = s.execute("SELECT v FROM mytable WHERE k=0")[0]
self.assertEqual(created_udt, result.v)
def nested_udt_helper(self, udts, i):
"""
Helper for creating nested udts.
"""
if i == 0:
return udts[0](42, 'Bob')
else:
return udts[i](self.nested_udt_helper(udts, i - 1))
def test_nested_registered_udts(self):
"""
Test for ensuring nested udts are handled correctly.
"""
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("User-defined types were introduced in Cassandra 2.1")
MAX_NESTING_DEPTH = 4 # TODO: Move to 128, or similar
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
# set the row_factory to dict_factory for programmatically accessing values
s.row_factory = dict_factory
s.execute("""CREATE KEYSPACE test_nested_unregistered_udts
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_nested_unregistered_udts")
# create the seed udt
s.execute("CREATE TYPE depth_0 (age int, name text)")
# create the nested udts
for i in range(MAX_NESTING_DEPTH):
s.execute("CREATE TYPE depth_{} (value depth_{})".format(i + 1, i))
# create a table with multiple sizes of nested udts
# no need for all nested types, only a spot checked few and the largest one
s.execute("CREATE TABLE mytable ("
"k int PRIMARY KEY, "
"v_0 depth_0, "
"v_1 depth_1, "
"v_2 depth_2, "
"v_3 depth_3, "
"v_{0} depth_{0})".format(MAX_NESTING_DEPTH))
# create the udt container
udts = []
# create and register the seed udt type
udt = namedtuple('depth_0', ('age', 'name'))
udts.append(udt)
c.register_user_type("test_nested_unregistered_udts", "depth_0", udts[0])
# create and register the nested udt types
for i in range(MAX_NESTING_DEPTH):
udt = namedtuple('depth_{}'.format(i + 1), ('value'))
udts.append(udt)
c.register_user_type("test_nested_unregistered_udts", "depth_{}".format(i + 1), udts[i + 1])
# verify inserts and reads
for i in (0, 1, 2, 3, MAX_NESTING_DEPTH):
# create udt
udt = self.nested_udt_helper(udts, i)
# write udt
s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, udt))
# verify udt was written and read correctly
result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
self.assertEqual(udt, result['v_%s' % i])
def test_nested_unregistered_udts(self):
"""
Test for ensuring nested unregistered udts are handled correctly.
"""
# close copy to test_nested_registered_udts
pass
def test_nested_registered_udts_with_different_namedtuples(self):
"""
Test for ensuring nested udts are handled correctly when the
created namedtuples use names that differ from the CQL type.
"""
# close copy to test_nested_registered_udts
pass
def test_non_existing_types(self):
c = Cluster(protocol_version=PROTOCOL_VERSION)
c.connect()
User = namedtuple('user', ('age', 'name'))
self.assertRaises(UserTypeDoesNotExist, c.register_user_type, "some_bad_keyspace", "user", User)
self.assertRaises(UserTypeDoesNotExist, c.register_user_type, "system", "user", User)
def test_primitive_datatypes(self):
"""
Test for inserting various types of DATA_TYPE_PRIMITIVES into UDT's
"""
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
# create keyspace
s.execute("""
CREATE KEYSPACE test_primitive_datatypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("test_primitive_datatypes")
# create UDT
alpha_type_list = []
start_index = ord('a')
for i, datatype in enumerate(DATA_TYPE_PRIMITIVES):
alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype))
s.execute("""
CREATE TYPE alldatatypes ({0})
""".format(', '.join(alpha_type_list))
)
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b alldatatypes)")
# register UDT
alphabet_list = []
for i in range(ord('a'), ord('a') + len(DATA_TYPE_PRIMITIVES)):
alphabet_list.append('{}'.format(chr(i)))
Alldatatypes = namedtuple("alldatatypes", alphabet_list)
c.register_user_type("test_primitive_datatypes", "alldatatypes", Alldatatypes)
# insert UDT data
params = []
for datatype in DATA_TYPE_PRIMITIVES:
params.append((get_sample(datatype)))
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
s.execute(insert, (0, Alldatatypes(*params)))
# retrieve and verify data
results = s.execute("SELECT * FROM mytable")
self.assertEqual(1, len(results))
row = results[0].b
for expected, actual in zip(params, row):
self.assertEqual(expected, actual)
c.shutdown()
def test_nonprimitive_datatypes(self):
"""
Test for inserting various types of DATA_TYPE_NON_PRIMITIVE into UDT's
"""
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
# create keyspace
s.execute("""
CREATE KEYSPACE test_nonprimitive_datatypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("test_nonprimitive_datatypes")
# create UDT
alpha_type_list = []
start_index = ord('a')
for i, nonprim_datatype in enumerate(DATA_TYPE_NON_PRIMITIVE_NAMES):
for j, datatype in enumerate(DATA_TYPE_PRIMITIVES):
if nonprim_datatype == "map":
type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j),
nonprim_datatype, datatype)
else:
type_string = "{0}_{1} {2}<{3}>".format(chr(start_index + i), chr(start_index + j),
nonprim_datatype, datatype)
alpha_type_list.append(type_string)
s.execute("""
CREATE TYPE alldatatypes ({0})
""".format(', '.join(alpha_type_list))
)
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b alldatatypes)")
# register UDT
alphabet_list = []
for i in range(ord('a'), ord('a') + len(DATA_TYPE_NON_PRIMITIVE_NAMES)):
for j in range(ord('a'), ord('a') + len(DATA_TYPE_PRIMITIVES)):
alphabet_list.append('{0}_{1}'.format(chr(i), chr(j)))
Alldatatypes = namedtuple("alldatatypes", alphabet_list)
c.register_user_type("test_nonprimitive_datatypes", "alldatatypes", Alldatatypes)
# insert UDT data
params = []
for nonprim_datatype in DATA_TYPE_NON_PRIMITIVE_NAMES:
for datatype in DATA_TYPE_PRIMITIVES:
params.append(get_nonprim_sample(nonprim_datatype, datatype))
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
s.execute(insert, (0, Alldatatypes(*params)))
# retrieve and verify data
results = s.execute("SELECT * FROM mytable")
self.assertEqual(1, len(results))
row = results[0].b
for expected, actual in zip(params, row):
self.assertEqual(expected, actual)
c.shutdown()

View File

@@ -240,21 +240,3 @@ class HostConnectionPoolTests(unittest.TestCase):
self.assertEqual(a, b, 'Two Host instances should be equal when sharing the same address.')
self.assertNotEqual(a, c, 'Two Host instances should NOT be equal when using two different addresses.')
self.assertNotEqual(b, c, 'Two Host instances should NOT be equal when using two different addresses.')
class HostTests(unittest.TestCase):
def test_version_parsing(self):
host = Host('127.0.0.1', SimpleConvictionPolicy)
host.set_version("1.0.0")
self.assertEqual((1, 0, 0), host.version)
host.set_version("1.0")
self.assertEqual((1, 0, 0), host.version)
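# a two-component version string is padded with a trailing zero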
host.set_version("1.0.0-beta1")
self.assertEqual((1, 0, 0, 'beta1'), host.version)
host.set_version("1.0-SNAPSHOT")
self.assertEqual((1, 0, 0, 'SNAPSHOT'), host.version)

View File

@@ -15,14 +15,18 @@
try:
import unittest2 as unittest
except ImportError:
import unittest  # noqa
from mock import Mock
import cassandra
from cassandra.cqltypes import IntegerType, AsciiType, TupleType
from cassandra.metadata import (Murmur3Token, MD5Token,
BytesToken, ReplicationStrategy,
NetworkTopologyStrategy, SimpleStrategy,
LocalStrategy, NoMurmur3, protect_name,
protect_names, protect_value, is_valid_name,
UserType, KeyspaceMetadata)
from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host
@@ -181,12 +185,12 @@ class TestNameEscaping(unittest.TestCase):
'tests ?!@#$%^&*()',
'1'
]),
[
'tests',
"\"test's\"",
'"tests ?!@#$%^&*()"',
'"1"'
])
def test_protect_value(self):
"""
@@ -248,3 +252,63 @@ class TestTokens(unittest.TestCase):
self.fail('Tokens for ByteOrderedPartitioner should be only strings')
except TypeError:
pass
class TestKeyspaceMetadata(unittest.TestCase):
def test_export_as_string_user_types(self):
keyspace_name = 'test'
keyspace = KeyspaceMetadata(keyspace_name, True, 'SimpleStrategy', dict(replication_factor=3))
keyspace.user_types['a'] = UserType(keyspace_name, 'a', ['one', 'two'],
[self.mock_user_type('UserType', 'c'),
self.mock_user_type('IntType', 'int')])
keyspace.user_types['b'] = UserType(keyspace_name, 'b', ['one', 'two', 'three'],
[self.mock_user_type('UserType', 'd'),
self.mock_user_type('IntType', 'int'),
self.mock_user_type('UserType', 'a')])
keyspace.user_types['c'] = UserType(keyspace_name, 'c', ['one'],
[self.mock_user_type('IntType', 'int')])
keyspace.user_types['d'] = UserType(keyspace_name, 'd', ['one'],
[self.mock_user_type('UserType', 'c')])
self.assertEqual("""CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} AND durable_writes = true;
CREATE TYPE test.c (
one int
);
CREATE TYPE test.a (
one c,
two int
);
CREATE TYPE test.d (
one c
);
CREATE TYPE test.b (
one d,
two int,
three a
);""", keyspace.export_as_string())
def mock_user_type(self, cassname, typename):
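# a stand-in for a cqltypes class; export_as_string() presumably only needs
# cassname, typename, and cql_parameterized_type() from each field type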
return Mock(**{'cassname': cassname, 'typename': typename, 'cql_parameterized_type.return_value': typename})
class TestUserTypes(unittest.TestCase):
def test_as_cql_query(self):
field_types = [IntegerType, AsciiType, TupleType.apply_parameters([IntegerType, AsciiType])]
udt = UserType("ks1", "mytype", ["a", "b", "c"], field_types)
self.assertEqual("CREATE TYPE ks1.mytype (a varint, b ascii, c tuple<varint, ascii>);", udt.as_cql_query(formatted=False))
self.assertEqual("""CREATE TYPE ks1.mytype (
a varint,
b ascii,
c tuple<varint, ascii>
);""", udt.as_cql_query(formatted=True))
def test_as_cql_query_name_escaping(self):
udt = UserType("MyKeyspace", "MyType", ["AbA", "keyspace"], [AsciiType, AsciiType])
self.assertEqual('CREATE TYPE "MyKeyspace"."MyType" ("AbA" ascii, "keyspace" ascii);', udt.as_cql_query(formatted=False))

View File

@@ -17,7 +17,7 @@ try:
except ImportError:
import unittest # noqa
from cassandra.encoder import Encoder
from cassandra.query import bind_params, ValueSequence
from cassandra.query import PreparedStatement, BoundStatement
from cassandra.cqltypes import Int32Type
@@ -29,31 +29,31 @@ from six.moves import xrange
class ParamBindingTest(unittest.TestCase):
def test_bind_sequence(self):
result = bind_params("%s %s %s", (1, "a", 2.0), cql_encoders)
result = bind_params("%s %s %s", (1, "a", 2.0), Encoder())
self.assertEqual(result, "1 'a' 2.0")
def test_bind_map(self):
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), cql_encoders)
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), Encoder())
self.assertEqual(result, "1 'a' 2.0")
def test_sequence_param(self):
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), cql_encoders)
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), Encoder())
self.assertEqual(result, "( 1 , 'a' , 2.0 )")
def test_generator_param(self):
result = bind_params("%s", ((i for i in xrange(3)),), cql_encoders)
result = bind_params("%s", ((i for i in xrange(3)),), Encoder())
self.assertEqual(result, "[ 0 , 1 , 2 ]")
def test_none_param(self):
result = bind_params("%s", (None,), cql_encoders)
result = bind_params("%s", (None,), Encoder())
self.assertEqual(result, "NULL")
def test_list_collection(self):
result = bind_params("%s", (['a', 'b', 'c'],), cql_encoders)
result = bind_params("%s", (['a', 'b', 'c'],), Encoder())
self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
def test_set_collection(self):
result = bind_params("%s", (set(['a', 'b']),), cql_encoders)
result = bind_params("%s", (set(['a', 'b']),), Encoder())
self.assertIn(result, ("{ 'a' , 'b' }", "{ 'b' , 'a' }"))
def test_map_collection(self):
@@ -61,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
vals['a'] = 'a'
vals['b'] = 'b'
vals['c'] = 'c'
result = bind_params("%s", (vals,), cql_encoders)
result = bind_params("%s", (vals,), Encoder())
self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
def test_quote_escaping(self):
result = bind_params("%s", ("""'ef''ef"ef""ef'""",), cql_encoders)
result = bind_params("%s", ("""'ef''ef"ef""ef'""",), Encoder())
self.assertEqual(result, """'''ef''''ef"ef""ef'''""")