Refactored TypeTests

This commit is contained in:
Kishan Karunaratne
2015-04-07 13:27:16 -07:00
parent 521ab1a116
commit 73ec60f606
2 changed files with 356 additions and 378 deletions

View File

@@ -11,22 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from decimal import Decimal from decimal import Decimal
import datetime from datetime import datetime, date, time
from uuid import UUID from uuid import uuid1, uuid4
import pytz
try: try:
from blist import sortedset from blist import sortedset
except ImportError: except ImportError:
sortedset = set # noqa sortedset = set # noqa
DATA_TYPE_PRIMITIVES = [ from cassandra.util import OrderedMap
from tests.integration import get_server_versions
PRIMITIVE_DATATYPES = [
'ascii', 'ascii',
'bigint', 'bigint',
'blob', 'blob',
'boolean', 'boolean',
# 'counter', counters are not allowed inside tuples
'decimal', 'decimal',
'double', 'double',
'float', 'float',
@@ -40,22 +44,28 @@ DATA_TYPE_PRIMITIVES = [
'varint', 'varint',
] ]
DATA_TYPE_NON_PRIMITIVE_NAMES = [ COLLECTION_TYPES = [
'list', 'list',
'set', 'set',
'map', 'map',
'tuple'
] ]
def get_sample_data(): def update_datatypes():
""" _cass_version, _cql_version = get_server_versions()
Create a standard set of sample inputs for testing.
"""
if _cass_version >= (2, 1, 0):
COLLECTION_TYPES.append('tuple')
if _cass_version >= (2, 1, 5):
PRIMITIVE_DATATYPES.append('date')
PRIMITIVE_DATATYPES.append('time')
def get_sample_data():
sample_data = {} sample_data = {}
for datatype in DATA_TYPE_PRIMITIVES: for datatype in PRIMITIVE_DATATYPES:
if datatype == 'ascii': if datatype == 'ascii':
sample_data[datatype] = 'ascii' sample_data[datatype] = 'ascii'
@@ -68,10 +78,6 @@ def get_sample_data():
elif datatype == 'boolean': elif datatype == 'boolean':
sample_data[datatype] = True sample_data[datatype] = True
elif datatype == 'counter':
# Not supported in an insert statement
pass
elif datatype == 'decimal': elif datatype == 'decimal':
sample_data[datatype] = Decimal('12.3E+7') sample_data[datatype] = Decimal('12.3E+7')
@@ -91,13 +97,13 @@ def get_sample_data():
sample_data[datatype] = 'text' sample_data[datatype] = 'text'
elif datatype == 'timestamp': elif datatype == 'timestamp':
sample_data[datatype] = datetime.datetime.fromtimestamp(872835240, tz=pytz.timezone('America/New_York')).astimezone(pytz.UTC).replace(tzinfo=None) sample_data[datatype] = datetime(2013, 12, 31, 23, 59, 59, 999000)
elif datatype == 'timeuuid': elif datatype == 'timeuuid':
sample_data[datatype] = UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66') sample_data[datatype] = uuid1()
elif datatype == 'uuid': elif datatype == 'uuid':
sample_data[datatype] = UUID('067e6162-3b6f-4ae2-a171-2470b63dff00') sample_data[datatype] = uuid4()
elif datatype == 'varchar': elif datatype == 'varchar':
sample_data[datatype] = 'varchar' sample_data[datatype] = 'varchar'
@@ -105,35 +111,40 @@ def get_sample_data():
elif datatype == 'varint': elif datatype == 'varint':
sample_data[datatype] = int(str(2147483647) + '000') sample_data[datatype] = int(str(2147483647) + '000')
elif datatype == 'date':
sample_data[datatype] = date(2015, 1, 15)
elif datatype == 'time':
sample_data[datatype] = time(16, 47, 25, 7)
else: else:
raise Exception('Missing handling of %s.' % datatype) raise Exception("Missing handling of {0}".format(datatype))
return sample_data return sample_data
SAMPLE_DATA = get_sample_data() SAMPLE_DATA = get_sample_data()
def get_sample(datatype): def get_sample(datatype):
""" """
Helper method to access created sample data Helper method to access created sample data for primitive types
""" """
return SAMPLE_DATA[datatype] return SAMPLE_DATA[datatype]
def get_nonprim_sample(non_prim_type, datatype):
def get_collection_sample(collection_type, datatype):
""" """
Helper method to access created sample data for non-primitives Helper method to access created sample data for collection types
""" """
if non_prim_type == 'list': if collection_type == 'list':
return [get_sample(datatype), get_sample(datatype)] return [get_sample(datatype), get_sample(datatype)]
elif non_prim_type == 'set': elif collection_type == 'set':
return sortedset([get_sample(datatype)]) return sortedset([get_sample(datatype)])
elif non_prim_type == 'map': elif collection_type == 'map':
if datatype == 'blob': return OrderedMap([(get_sample(datatype), get_sample(datatype))])
return {get_sample('ascii'): get_sample(datatype)} elif collection_type == 'tuple':
else:
return {get_sample(datatype): get_sample(datatype)}
elif non_prim_type == 'tuple':
return (get_sample(datatype),) return (get_sample(datatype),)
else: else:
raise Exception('Missing handling of non-primitive type {0}.'.format(non_prim_type)) raise Exception('Missing handling of non-primitive type {0}.'.format(collection_type))

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from tests.integration.datatype_utils import get_sample, DATA_TYPE_PRIMITIVES, DATA_TYPE_NON_PRIMITIVE_NAMES
try: try:
import unittest2 as unittest import unittest2 as unittest
@@ -21,85 +20,53 @@ except ImportError:
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
from collections import namedtuple from datetime import datetime
from decimal import Decimal
from datetime import datetime, date, time
from functools import partial
import six import six
from uuid import uuid1, uuid4
from cassandra import InvalidRequest from cassandra import InvalidRequest
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type, EMPTY from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.query import dict_factory from cassandra.query import dict_factory, ordered_dict_factory
from cassandra.util import OrderedMap, sortedset from cassandra.util import sortedset
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \
# defined in module scope for pickling in OrderedMap get_sample, get_collection_sample
nested_collection_udt = namedtuple('nested_collection_udt', ['m', 't', 'l', 's'])
nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u'])
def setup_module(): def setup_module():
use_singledc() use_singledc()
update_datatypes()
class TypeTests(unittest.TestCase): class TypeTests(unittest.TestCase):
_types_table_created = False def setUp(self):
self._cass_version, self._cql_version = get_server_versions()
@classmethod self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
def setup_class(cls): self.session = self.cluster.connect()
cls._cass_version, cls._cql_version = get_server_versions() self.session.execute("CREATE KEYSPACE typetests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
self.cluster.shutdown()
cls._col_types = ['text', def tearDown(self):
'ascii', self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
'bigint', self.session = self.cluster.connect()
'boolean', self.session.execute("DROP KEYSPACE typetests")
'decimal', self.cluster.shutdown()
'double',
'float',
'inet',
'int',
'list<text>',
'set<int>',
'map<text,int>',
'timestamp',
'uuid',
'timeuuid',
'varchar',
'varint']
if cls._cass_version >= (2, 1, 4): def test_can_insert_blob_type_as_string(self):
cls._col_types.extend(('date', 'time')) """
Tests that blob type in Cassandra does not map to string in Python
"""
cls._session = Cluster(protocol_version=PROTOCOL_VERSION).connect() c = Cluster(protocol_version=PROTOCOL_VERSION)
cls._session.execute("CREATE KEYSPACE typetests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}") s = c.connect("typetests")
cls._session.set_keyspace("typetests")
@classmethod s.execute("CREATE TABLE blobstring (a ascii PRIMARY KEY, b blob)")
def teardown_class(cls):
cls._session.execute("DROP KEYSPACE typetests")
cls._session.cluster.shutdown()
def test_blob_type_as_string(self): params = ['key1', b'blobyblob']
s = self._session query = "INSERT INTO blobstring (a, b) VALUES (%s, %s)"
s.execute("""
CREATE TABLE blobstring (
a ascii,
b blob,
PRIMARY KEY (a)
)
""")
params = [
'key1',
b'blobyblob'
]
query = 'INSERT INTO blobstring (a, b) VALUES (%s, %s)'
# In python 3, the 'bytes' type is treated as a blob, so we can # In python 3, the 'bytes' type is treated as a blob, so we can
# correctly encode it with hex notation. # correctly encode it with hex notation.
@@ -118,243 +85,277 @@ class TypeTests(unittest.TestCase):
params[1] = params[1].encode('hex') params[1] = params[1].encode('hex')
s.execute(query, params) s.execute(query, params)
expected_vals = [
'key1',
bytearray(b'blobyblob')
]
results = s.execute("SELECT * FROM blobstring") results = s.execute("SELECT * FROM blobstring")[0]
for expected, actual in zip(params, results):
for expected, actual in zip(expected_vals, results[0]):
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
def test_blob_type_as_bytearray(self): c.shutdown()
s = self._session
s.execute("""
CREATE TABLE blobbytes (
a ascii,
b blob,
PRIMARY KEY (a)
)
""")
params = [ def test_can_insert_blob_type_as_bytearray(self):
'key1', """
bytearray(b'blob1') Tests that blob type in Cassandra maps to bytearray in Python
] """
query = 'INSERT INTO blobbytes (a, b) VALUES (%s, %s);' c = Cluster(protocol_version=PROTOCOL_VERSION)
s.execute(query, params) s = c.connect("typetests")
expected_vals = [ s.execute("CREATE TABLE blobbytes (a ascii PRIMARY KEY, b blob)")
'key1',
bytearray(b'blob1')
]
results = s.execute("SELECT * FROM blobbytes") params = ['key1', bytearray(b'blob1')]
s.execute("INSERT INTO blobbytes (a, b) VALUES (%s, %s)", params)
for expected, actual in zip(expected_vals, results[0]): results = s.execute("SELECT * FROM blobbytes")[0]
for expected, actual in zip(params, results):
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
def _create_all_types_table(self): c.shutdown()
if not self._types_table_created:
TypeTests._col_names = ["%s_col" % col_type.translate(None, '<> ,') for col_type in self._col_types]
cql = "CREATE TABLE alltypes ( key int PRIMARY KEY, %s)" % ','.join("%s %s" % name_type for name_type in zip(self._col_names, self._col_types))
self._session.execute(cql)
TypeTests._types_table_created = True
def test_basic_types(self): def test_can_insert_primitive_datatypes(self):
"""
Test insertion of all datatype primitives
"""
s = self._session.cluster.connect() c = Cluster(protocol_version=PROTOCOL_VERSION)
s.set_keyspace(self._session.keyspace) s = c.connect("typetests")
self._create_all_types_table() # create table
alpha_type_list = ["zz int PRIMARY KEY"]
col_names = ["zz"]
start_index = ord('a')
for i, datatype in enumerate(PRIMITIVE_DATATYPES):
alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype))
col_names.append(chr(start_index + i))
v1_uuid = uuid1() s.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list)))
v4_uuid = uuid4()
mydatetime = datetime(2013, 12, 31, 23, 59, 59, 999000)
# this could use some rework tying column types to names (instead of relying on position) # create the input
params = [ params = [0]
"text", for datatype in PRIMITIVE_DATATYPES:
"ascii", params.append((get_sample(datatype)))
12345678923456789, # bigint
True, # boolean
Decimal('1.234567890123456789'), # decimal
0.000244140625, # double
1.25, # float
"1.2.3.4", # inet
12345, # int
['a', 'b', 'c'], # list<text> collection
set([1, 2, 3]), # set<int> collection
{'a': 1, 'b': 2}, # map<text, int> collection
mydatetime, # timestamp
v4_uuid, # uuid
v1_uuid, # timeuuid
u"sometext\u1234", # varchar
123456789123456789123456789, # varint
]
expected_vals = [ # insert into table as a simple statement
"text", columns_string = ', '.join(col_names)
"ascii", placeholders = ', '.join(["%s"] * len(col_names))
12345678923456789, # bigint s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params)
True, # boolean
Decimal('1.234567890123456789'), # decimal
0.000244140625, # double
1.25, # float
"1.2.3.4", # inet
12345, # int
['a', 'b', 'c'], # list<text> collection
sortedset((1, 2, 3)), # set<int> collection
{'a': 1, 'b': 2}, # map<text, int> collection
mydatetime, # timestamp
v4_uuid, # uuid
v1_uuid, # timeuuid
u"sometext\u1234", # varchar
123456789123456789123456789, # varint
]
if self._cass_version >= (2, 1, 4): # verify data
mydate = date(2015, 1, 15) results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0]
mytime = time(16, 47, 25, 7) for expected, actual in zip(params, results):
params.append(mydate)
params.append(mytime)
expected_vals.append(mydate)
expected_vals.append(mytime)
columns_string = ','.join(self._col_names)
placeholders = ', '.join(["%s"] * len(self._col_names))
s.execute("INSERT INTO alltypes (key, %s) VALUES (0, %s)" %
(columns_string, placeholders), params)
results = s.execute("SELECT %s FROM alltypes WHERE key=0" % columns_string)
for expected, actual in zip(expected_vals, results[0]):
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
# try the same thing with a prepared statement # try the same thing with a prepared statement
placeholders = ','.join(["?"] * len(self._col_names)) placeholders = ','.join(["?"] * len(col_names))
prepared = s.prepare("INSERT INTO alltypes (key, %s) VALUES (1, %s)" % insert = s.prepare("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders))
(columns_string, placeholders)) s.execute(insert.bind(params))
s.execute(prepared.bind(params))
results = s.execute("SELECT %s FROM alltypes WHERE key=1" % columns_string) # verify data
results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0]
for expected, actual in zip(expected_vals, results[0]): for expected, actual in zip(params, results):
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
# query with prepared statement # verify data with prepared statement query
prepared = s.prepare("SELECT %s FROM alltypes WHERE key=?" % columns_string) select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string))
results = s.execute(prepared.bind((1,))) results = s.execute(select.bind([0]))[0]
for expected, actual in zip(params, results):
for expected, actual in zip(expected_vals, results[0]):
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
# query with prepared statement, no explicit columns # verify data with prepared statement, use dictionary with no explicit columns
s.row_factory = dict_factory s.row_factory = ordered_dict_factory
prepared = s.prepare("SELECT * FROM alltypes") select = s.prepare("SELECT * FROM alltypes")
results = s.execute(prepared.bind(())) results = s.execute(select)[0]
row = results[0] for expected, actual in zip(params, results.values()):
for expected, name in zip(expected_vals, self._col_names): self.assertEqual(actual, expected)
self.assertEqual(row[name], expected)
s.shutdown() c.shutdown()
def test_empty_strings_and_nones(self): def test_can_insert_collection_datatypes(self):
s = self._session.cluster.connect() """
s.set_keyspace(self._session.keyspace) Test insertion of all collection types
s.row_factory = dict_factory """
self._create_all_types_table() c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
# use tuple encoding, to convert native python tuple into raw CQL
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
columns_string = ','.join(self._col_names) # create table
s.execute("INSERT INTO alltypes (key) VALUES (2)") alpha_type_list = ["zz int PRIMARY KEY"]
results = s.execute("SELECT %s FROM alltypes WHERE key=2" % columns_string) col_names = ["zz"]
self.assertTrue(all(x is None for x in results[0].values())) start_index = ord('a')
for i, collection_type in enumerate(COLLECTION_TYPES):
for j, datatype in enumerate(PRIMITIVE_DATATYPES):
if collection_type == "map":
type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j),
collection_type, datatype)
elif collection_type == "tuple":
type_string = "{0}_{1} frozen<{2}<{3}>>".format(chr(start_index + i), chr(start_index + j),
collection_type, datatype)
else:
type_string = "{0}_{1} {2}<{3}>".format(chr(start_index + i), chr(start_index + j),
collection_type, datatype)
alpha_type_list.append(type_string)
col_names.append("{0}_{1}".format(chr(start_index + i), chr(start_index + j)))
prepared = s.prepare("SELECT %s FROM alltypes WHERE key=?" % columns_string) s.execute("CREATE TABLE allcoltypes ({0})".format(', '.join(alpha_type_list)))
results = s.execute(prepared.bind((2,))) columns_string = ', '.join(col_names)
self.assertTrue(all(x is None for x in results[0].values()))
# insert empty strings for string-like fields and fetch them # create the input for simple statement
expected_values = {'text_col': '', 'ascii_col': '', 'varchar_col': '', 'listtext_col': [''], 'maptextint_col': OrderedMap({'': 3})} params = [0]
for collection_type in COLLECTION_TYPES:
for datatype in PRIMITIVE_DATATYPES:
params.append((get_collection_sample(collection_type, datatype)))
# insert into table as a simple statement
placeholders = ', '.join(["%s"] * len(col_names))
s.execute("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params)
# verify data
results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string))[0]
for expected, actual in zip(params, results):
self.assertEqual(actual, expected)
# create the input for prepared statement
params = [0]
for collection_type in COLLECTION_TYPES:
for datatype in PRIMITIVE_DATATYPES:
params.append((get_collection_sample(collection_type, datatype)))
# try the same thing with a prepared statement
placeholders = ','.join(["?"] * len(col_names))
insert = s.prepare("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders))
s.execute(insert.bind(params))
# verify data
results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string))[0]
for expected, actual in zip(params, results):
self.assertEqual(actual, expected)
# verify data with prepared statement query
select = s.prepare("SELECT {0} FROM allcoltypes WHERE zz=?".format(columns_string))
results = s.execute(select.bind([0]))[0]
for expected, actual in zip(params, results):
self.assertEqual(actual, expected)
# verify data with prepared statement, use dictionary with no explicit columns
s.row_factory = ordered_dict_factory
select = s.prepare("SELECT * FROM allcoltypes")
results = s.execute(select)[0]
for expected, actual in zip(params, results.values()):
self.assertEqual(actual, expected)
c.shutdown()
def test_can_insert_empty_strings_and_nulls(self):
"""
Test insertion of empty strings and null values
"""
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
# create table
alpha_type_list = ["zz int PRIMARY KEY"]
col_names = []
start_index = ord('a')
for i, datatype in enumerate(PRIMITIVE_DATATYPES):
alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype))
col_names.append(chr(start_index + i))
s.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list)))
# verify all types initially null with simple statement
columns_string = ','.join(col_names)
s.execute("INSERT INTO alltypes (zz) VALUES (2)")
results = s.execute("SELECT {0} FROM alltypes WHERE zz=2".format(columns_string))[0]
self.assertTrue(all(x is None for x in results))
# verify all types initially null with prepared statement
select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string))
results = s.execute(select.bind([2]))[0]
self.assertTrue(all(x is None for x in results))
# insert empty strings for string-like fields
expected_values = {'j': '', 'a': '', 'n': ''}
columns_string = ','.join(expected_values.keys()) columns_string = ','.join(expected_values.keys())
placeholders = ','.join(["%s"] * len(expected_values)) placeholders = ','.join(["%s"] * len(expected_values))
s.execute("INSERT INTO alltypes (key, %s) VALUES (3, %s)" % (columns_string, placeholders), expected_values.values()) s.execute("INSERT INTO alltypes (zz, {0}) VALUES (3, {1})".format(columns_string, placeholders), expected_values.values())
self.assertEqual(expected_values,
s.execute("SELECT %s FROM alltypes WHERE key=3" % columns_string)[0]) # verify string types empty with simple statement
self.assertEqual(expected_values, results = s.execute("SELECT {0} FROM alltypes WHERE zz=3".format(columns_string))[0]
s.execute(s.prepare("SELECT %s FROM alltypes WHERE key=?" % columns_string), (3,))[0]) for expected, actual in zip(expected_values.values(), results):
self.assertEqual(actual, expected)
# verify string types empty with prepared statement
results = s.execute(s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string)), [3])[0]
for expected, actual in zip(expected_values.values(), results):
self.assertEqual(actual, expected)
# non-string types shouldn't accept empty strings # non-string types shouldn't accept empty strings
for col in ('bigint_col', 'boolean_col', 'decimal_col', 'double_col', for col in ('b', 'd', 'e', 'f', 'g', 'i', 'l', 'm', 'o'):
'float_col', 'int_col', 'listtext_col', 'setint_col', query = "INSERT INTO alltypes (zz, {0}) VALUES (4, %s)".format(col)
'maptextint_col', 'uuid_col', 'timeuuid_col', 'varint_col'): with self.assertRaises(InvalidRequest):
query = "INSERT INTO alltypes (key, %s) VALUES (4, %%s)" % col
try:
s.execute(query, ['']) s.execute(query, [''])
except InvalidRequest:
pass
else:
self.fail("Expected an InvalidRequest error when inserting an "
"emptry string for column %s" % (col, ))
prepared = s.prepare("INSERT INTO alltypes (key, %s) VALUES (4, ?)" % col) insert = s.prepare("INSERT INTO alltypes (zz, {0}) VALUES (4, ?)".format(col))
try: with self.assertRaises(TypeError):
s.execute(prepared, ['']) s.execute(insert, [''])
except TypeError:
pass
else:
self.fail("Expected an InvalidRequest error when inserting an "
"emptry string for column %s with a prepared statement" % (col, ))
# insert values for all columns # verify that Nones can be inserted and overwrites existing data
values = ['text', 'ascii', 1, True, Decimal('1.0'), 0.1, 0.1, # create the input
"1.2.3.4", 1, ['a'], set([1]), {'a': 1}, params = []
datetime.now(), uuid4(), uuid1(), 'a', 1] for datatype in PRIMITIVE_DATATYPES:
if self._cass_version >= (2, 1, 4): params.append((get_sample(datatype)))
values.append('2014-01-01')
values.append('01:02:03.456789012')
columns_string = ','.join(self._col_names) # insert the data
placeholders = ','.join(["%s"] * len(self._col_names)) columns_string = ','.join(col_names)
insert = "INSERT INTO alltypes (key, %s) VALUES (5, %s)" % (columns_string, placeholders) placeholders = ','.join(["%s"] * len(col_names))
s.execute(insert, values) simple_insert = "INSERT INTO alltypes (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders)
s.execute(simple_insert, params)
# then insert None, which should null them out # then insert None, which should null them out
null_values = [None] * len(self._col_names) null_values = [None] * len(col_names)
s.execute(insert, null_values) s.execute(simple_insert, null_values)
select = "SELECT %s FROM alltypes WHERE key=5" % columns_string # check via simple statement
results = s.execute(select) query = "SELECT {0} FROM alltypes WHERE zz=5".format(columns_string)
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) results = s.execute(query)[0]
for col in results:
self.assertEqual(None, col)
prepared = s.prepare(select) # check via prepared statement
results = s.execute(prepared.bind(())) select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string))
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) results = s.execute(select.bind([5]))[0]
for col in results:
self.assertEqual(None, col)
# do the same thing again, but use a prepared statement to insert the nulls # do the same thing again, but use a prepared statement to insert the nulls
s.execute(insert, values) s.execute(simple_insert, params)
placeholders = ','.join(["?"] * len(self._col_names)) placeholders = ','.join(["?"] * len(col_names))
prepared = s.prepare("INSERT INTO alltypes (key, %s) VALUES (5, %s)" % (columns_string, placeholders)) insert = s.prepare("INSERT INTO alltypes (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders))
s.execute(prepared, null_values) s.execute(insert, null_values)
results = s.execute(select) results = s.execute(query)[0]
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) for col in results:
self.assertEqual(None, col)
prepared = s.prepare(select) results = s.execute(select.bind([5]))[0]
results = s.execute(prepared.bind(())) for col in results:
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) self.assertEqual(None, col)
s.shutdown() s.shutdown()
def test_empty_values(self): def test_can_insert_empty_values_for_int32(self):
s = self._session """
Ensure Int32Type supports empty values
"""
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
s.execute("CREATE TABLE empty_values (a text PRIMARY KEY, b int)") s.execute("CREATE TABLE empty_values (a text PRIMARY KEY, b int)")
s.execute("INSERT INTO empty_values (a, b) VALUES ('a', blobAsInt(0x))") s.execute("INSERT INTO empty_values (a, b) VALUES ('a', blobAsInt(0x))")
try: try:
@@ -364,8 +365,13 @@ class TypeTests(unittest.TestCase):
finally: finally:
Int32Type.support_empty_values = False Int32Type.support_empty_values = False
def test_timezone_aware_datetimes(self): c.shutdown()
""" Ensure timezone-aware datetimes are converted to timestamps correctly """
def test_timezone_aware_datetimes_are_timestamps(self):
"""
Ensure timezone-aware datetimes are converted to timestamps correctly
"""
try: try:
import pytz import pytz
except ImportError as exc: except ImportError as exc:
@@ -375,22 +381,25 @@ class TypeTests(unittest.TestCase):
eastern_tz = pytz.timezone('US/Eastern') eastern_tz = pytz.timezone('US/Eastern')
eastern_tz.localize(dt) eastern_tz.localize(dt)
s = self._session c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
s.execute("CREATE TABLE tz_aware (a ascii PRIMARY KEY, b timestamp)") s.execute("CREATE TABLE tz_aware (a ascii PRIMARY KEY, b timestamp)")
# test non-prepared statement # test non-prepared statement
s.execute("INSERT INTO tz_aware (a, b) VALUES ('key1', %s)", parameters=(dt,)) s.execute("INSERT INTO tz_aware (a, b) VALUES ('key1', %s)", [dt])
result = s.execute("SELECT b FROM tz_aware WHERE a='key1'")[0].b result = s.execute("SELECT b FROM tz_aware WHERE a='key1'")[0].b
self.assertEqual(dt.utctimetuple(), result.utctimetuple()) self.assertEqual(dt.utctimetuple(), result.utctimetuple())
# test prepared statement # test prepared statement
prepared = s.prepare("INSERT INTO tz_aware (a, b) VALUES ('key2', ?)") insert = s.prepare("INSERT INTO tz_aware (a, b) VALUES ('key2', ?)")
s.execute(prepared, parameters=(dt,)) s.execute(insert.bind([dt]))
result = s.execute("SELECT b FROM tz_aware WHERE a='key2'")[0].b result = s.execute("SELECT b FROM tz_aware WHERE a='key2'")[0].b
self.assertEqual(dt.utctimetuple(), result.utctimetuple()) self.assertEqual(dt.utctimetuple(), result.utctimetuple())
def test_tuple_type(self): c.shutdown()
def test_can_insert_tuples(self):
""" """
Basic test of tuple functionality Basic test of tuple functionality
""" """
@@ -398,8 +407,8 @@ class TypeTests(unittest.TestCase):
if self._cass_version < (2, 1, 0): if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session.cluster.connect() c = Cluster(protocol_version=PROTOCOL_VERSION)
s.set_keyspace(self._session.keyspace) s = c.connect("typetests")
# use this encoder in order to insert tuples # use this encoder in order to insert tuples
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
@@ -439,9 +448,9 @@ class TypeTests(unittest.TestCase):
self.assertEqual(partial_result, s.execute(prepared, (4,))[0].b) self.assertEqual(partial_result, s.execute(prepared, (4,))[0].b)
self.assertEqual(subpartial_result, s.execute(prepared, (5,))[0].b) self.assertEqual(subpartial_result, s.execute(prepared, (5,))[0].b)
s.shutdown() c.shutdown()
def test_tuple_type_varying_lengths(self): def test_can_insert_tuples_with_varying_lengths(self):
""" """
Test tuple types of lengths of 1, 2, 3, and 384 to ensure edge cases work Test tuple types of lengths of 1, 2, 3, and 384 to ensure edge cases work
as expected. as expected.
@@ -450,8 +459,8 @@ class TypeTests(unittest.TestCase):
if self._cass_version < (2, 1, 0): if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session.cluster.connect() c = Cluster(protocol_version=PROTOCOL_VERSION)
s.set_keyspace(self._session.keyspace) s = c.connect("typetests")
# set the row_factory to dict_factory for programmatic access # set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples # set the encoder for tuples for the ability to write tuples
@@ -479,9 +488,9 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT v_%s FROM tuple_lengths WHERE k=0", (i,))[0] result = s.execute("SELECT v_%s FROM tuple_lengths WHERE k=0", (i,))[0]
self.assertEqual(tuple(created_tuple), result['v_%s' % i]) self.assertEqual(tuple(created_tuple), result['v_%s' % i])
s.shutdown() c.shutdown()
def test_tuple_primitive_subtypes(self): def test_can_insert_tuples_all_primitive_datatypes(self):
""" """
Ensure tuple subtypes are appropriately handled. Ensure tuple subtypes are appropriately handled.
""" """
@@ -489,28 +498,28 @@ class TypeTests(unittest.TestCase):
if self._cass_version < (2, 1, 0): if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session.cluster.connect() c = Cluster(protocol_version=PROTOCOL_VERSION)
s.set_keyspace(self._session.keyspace) s = c.connect("typetests")
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("CREATE TABLE tuple_primitive (" s.execute("CREATE TABLE tuple_primitive ("
"k int PRIMARY KEY, " "k int PRIMARY KEY, "
"v frozen<tuple<%s>>)" % ','.join(DATA_TYPE_PRIMITIVES)) "v frozen<tuple<%s>>)" % ','.join(PRIMITIVE_DATATYPES))
for i in range(len(DATA_TYPE_PRIMITIVES)): for i in range(len(PRIMITIVE_DATATYPES)):
# create tuples to be written and ensure they match with the expected response # create tuples to be written and ensure they match with the expected response
# responses have trailing None values for every element that has not been written # responses have trailing None values for every element that has not been written
created_tuple = [get_sample(DATA_TYPE_PRIMITIVES[j]) for j in range(i + 1)] created_tuple = [get_sample(PRIMITIVE_DATATYPES[j]) for j in range(i + 1)]
response_tuple = tuple(created_tuple + [None for j in range(len(DATA_TYPE_PRIMITIVES) - i - 1)]) response_tuple = tuple(created_tuple + [None for j in range(len(PRIMITIVE_DATATYPES) - i - 1)])
written_tuple = tuple(created_tuple) written_tuple = tuple(created_tuple)
s.execute("INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, written_tuple)) s.execute("INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, written_tuple))
result = s.execute("SELECT v FROM tuple_primitive WHERE k=%s", (i,))[0] result = s.execute("SELECT v FROM tuple_primitive WHERE k=%s", (i,))[0]
self.assertEqual(response_tuple, result.v) self.assertEqual(response_tuple, result.v)
s.shutdown() c.shutdown()
def test_tuple_non_primitive_subtypes(self): def test_can_insert_tuples_all_collection_datatypes(self):
""" """
Ensure tuple subtypes are appropriately handled for maps, sets, and lists. Ensure tuple subtypes are appropriately handled for maps, sets, and lists.
""" """
@@ -518,8 +527,8 @@ class TypeTests(unittest.TestCase):
if self._cass_version < (2, 1, 0): if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session.cluster.connect() c = Cluster(protocol_version=PROTOCOL_VERSION)
s.set_keyspace(self._session.keyspace) s = c.connect("typetests")
# set the row_factory to dict_factory for programmatic access # set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples # set the encoder for tuples for the ability to write tuples
@@ -529,15 +538,15 @@ class TypeTests(unittest.TestCase):
values = [] values = []
# create list values # create list values
for datatype in DATA_TYPE_PRIMITIVES: for datatype in PRIMITIVE_DATATYPES:
values.append('v_{} frozen<tuple<list<{}>>>'.format(len(values), datatype)) values.append('v_{} frozen<tuple<list<{}>>>'.format(len(values), datatype))
# create set values # create set values
for datatype in DATA_TYPE_PRIMITIVES: for datatype in PRIMITIVE_DATATYPES:
values.append('v_{} frozen<tuple<set<{}>>>'.format(len(values), datatype)) values.append('v_{} frozen<tuple<set<{}>>>'.format(len(values), datatype))
# create map values # create map values
for datatype in DATA_TYPE_PRIMITIVES: for datatype in PRIMITIVE_DATATYPES:
datatype_1 = datatype_2 = datatype datatype_1 = datatype_2 = datatype
if datatype == 'blob': if datatype == 'blob':
# unhashable type: 'bytearray' # unhashable type: 'bytearray'
@@ -545,9 +554,9 @@ class TypeTests(unittest.TestCase):
values.append('v_{} frozen<tuple<map<{}, {}>>>'.format(len(values), datatype_1, datatype_2)) values.append('v_{} frozen<tuple<map<{}, {}>>>'.format(len(values), datatype_1, datatype_2))
# make sure we're testing all non primitive data types in the future # make sure we're testing all non primitive data types in the future
if set(DATA_TYPE_NON_PRIMITIVE_NAMES) != set(['tuple', 'list', 'map', 'set']): if set(COLLECTION_TYPES) != set(['tuple', 'list', 'map', 'set']):
raise NotImplemented('Missing datatype not implemented: {}'.format( raise NotImplemented('Missing datatype not implemented: {}'.format(
set(DATA_TYPE_NON_PRIMITIVE_NAMES) - set(['tuple', 'list', 'map', 'set']) set(COLLECTION_TYPES) - set(['tuple', 'list', 'map', 'set'])
)) ))
# create table # create table
@@ -557,7 +566,7 @@ class TypeTests(unittest.TestCase):
i = 0 i = 0
# test tuple<list<datatype>> # test tuple<list<datatype>>
for datatype in DATA_TYPE_PRIMITIVES: for datatype in PRIMITIVE_DATATYPES:
created_tuple = tuple([[get_sample(datatype)]]) created_tuple = tuple([[get_sample(datatype)]])
s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple))
@@ -566,7 +575,7 @@ class TypeTests(unittest.TestCase):
i += 1 i += 1
# test tuple<set<datatype>> # test tuple<set<datatype>>
for datatype in DATA_TYPE_PRIMITIVES: for datatype in PRIMITIVE_DATATYPES:
created_tuple = tuple([sortedset([get_sample(datatype)])]) created_tuple = tuple([sortedset([get_sample(datatype)])])
s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple))
@@ -575,7 +584,7 @@ class TypeTests(unittest.TestCase):
i += 1 i += 1
# test tuple<map<datatype, datatype>> # test tuple<map<datatype, datatype>>
for datatype in DATA_TYPE_PRIMITIVES: for datatype in PRIMITIVE_DATATYPES:
if datatype == 'blob': if datatype == 'blob':
# unhashable type: 'bytearray' # unhashable type: 'bytearray'
created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}]) created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}])
@@ -587,7 +596,7 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,))[0] result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,))[0]
self.assertEqual(created_tuple, result['v_%s' % i]) self.assertEqual(created_tuple, result['v_%s' % i])
i += 1 i += 1
s.shutdown() c.shutdown()
def nested_tuples_schema_helper(self, depth): def nested_tuples_schema_helper(self, depth):
""" """
@@ -609,7 +618,7 @@ class TypeTests(unittest.TestCase):
else: else:
return (self.nested_tuples_creator_helper(depth - 1), ) return (self.nested_tuples_creator_helper(depth - 1), )
def test_nested_tuples(self): def test_can_insert_nested_tuples(self):
""" """
Ensure nested tuples are appropriately handled. Ensure nested tuples are appropriately handled.
""" """
@@ -617,8 +626,8 @@ class TypeTests(unittest.TestCase):
if self._cass_version < (2, 1, 0): if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session.cluster.connect() c = Cluster(protocol_version=PROTOCOL_VERSION)
s.set_keyspace(self._session.keyspace) s = c.connect("typetests")
# set the row_factory to dict_factory for programmatic access # set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples # set the encoder for tuples for the ability to write tuples
@@ -647,16 +656,18 @@ class TypeTests(unittest.TestCase):
# verify tuple was written and read correctly # verify tuple was written and read correctly
result = s.execute("SELECT v_%s FROM nested_tuples WHERE k=%s", (i, i))[0] result = s.execute("SELECT v_%s FROM nested_tuples WHERE k=%s", (i, i))[0]
self.assertEqual(created_tuple, result['v_%s' % i]) self.assertEqual(created_tuple, result['v_%s' % i])
s.shutdown() c.shutdown()
def test_tuples_with_nulls(self): def test_can_insert_tuples_with_nulls(self):
""" """
Test tuples with null and empty string fields. Test tuples with null and empty string fields.
""" """
if self._cass_version < (2, 1, 0): if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
s.execute("CREATE TABLE tuples_nulls (k int PRIMARY KEY, t frozen<tuple<text, int, uuid, blob>>)") s.execute("CREATE TABLE tuples_nulls (k int PRIMARY KEY, t frozen<tuple<text, int, uuid, blob>>)")
@@ -675,75 +686,29 @@ class TypeTests(unittest.TestCase):
self.assertEqual(('', None, None, b''), result[0].t) self.assertEqual(('', None, None, b''), result[0].t)
self.assertEqual(('', None, None, b''), s.execute(read)[0].t) self.assertEqual(('', None, None, b''), s.execute(read)[0].t)
def test_unicode_query_string(self): c.shutdown()
s = self._session
def test_can_insert_unicode_query_string(self):
"""
Test to ensure unicode strings can be used in a query
"""
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
query = u"SELECT * FROM system.schema_columnfamilies WHERE keyspace_name = 'ef\u2052ef' AND columnfamily_name = %s" query = u"SELECT * FROM system.schema_columnfamilies WHERE keyspace_name = 'ef\u2052ef' AND columnfamily_name = %s"
s.execute(query, (u"fe\u2051fe",)) s.execute(query, (u"fe\u2051fe",))
def insert_select_column(self, session, table_name, column_name, value): c.shutdown()
insert = session.prepare("INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name))
session.execute(insert, (0, value))
result = session.execute("SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,))[0][0]
self.assertEqual(result, value)
def test_nested_collections(self): def test_can_read_composite_type(self):
"""
Test to ensure that CompositeTypes can be used in a query
"""
if self._cass_version < (2, 1, 3): c = Cluster(protocol_version=PROTOCOL_VERSION)
raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3") s = c.connect("typetests")
if PROTOCOL_VERSION < 3:
raise unittest.SkipTest("Protocol version > 3 required for nested collections")
name = self._testMethodName
s = self._session.cluster.connect()
s.set_keyspace(self._session.keyspace)
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""
CREATE TYPE %s (
m frozen<map<int,text>>,
t tuple<int,text>,
l frozen<list<int>>,
s frozen<set<int>>
)""" % name)
s.execute("""
CREATE TYPE %s_nested (
m frozen<map<int,text>>,
t tuple<int,text>,
l frozen<list<int>>,
s frozen<set<int>>,
u frozen<%s>
)""" % (name, name))
s.execute("""
CREATE TABLE %s (
k int PRIMARY KEY,
map_map map<frozen<map<int,int>>, frozen<map<int,int>>>,
map_set map<frozen<set<int>>, frozen<set<int>>>,
map_list map<frozen<list<int>>, frozen<list<int>>>,
map_tuple map<frozen<tuple<int, int>>, frozen<tuple<int>>>,
map_udt map<frozen<%s_nested>, frozen<%s>>,
)""" % (name, name, name))
validate = partial(self.insert_select_column, s, name)
validate('map_map', OrderedMap([({1: 1, 2: 2}, {3: 3, 4: 4}), ({5: 5, 6: 6}, {7: 7, 8: 8})]))
validate('map_set', OrderedMap([(set((1, 2)), set((3, 4))), (set((5, 6)), set((7, 8)))]))
validate('map_list', OrderedMap([([1, 2], [3, 4]), ([5, 6], [7, 8])]))
validate('map_tuple', OrderedMap([((1, 2), (3,)), ((4, 5), (6,))]))
value = nested_collection_udt({1: 'v1', 2: 'v2'}, (3, 'v3'), [4, 5, 6, 7], set((8, 9, 10)))
key = nested_collection_udt_nested(value.m, value.t, value.l, value.s, value)
key2 = nested_collection_udt_nested({3: 'v3'}, value.t, value.l, value.s, value)
validate('map_udt', OrderedMap([(key, value), (key2, value)]))
s.execute("DROP TABLE %s" % (name))
s.execute("DROP TYPE %s_nested" % (name))
s.execute("DROP TYPE %s" % (name))
s.shutdown()
def test_reading_composite_type(self):
s = self._session
s.execute(""" s.execute("""
CREATE TABLE composites ( CREATE TABLE composites (
a int PRIMARY KEY, a int PRIMARY KEY,
@@ -761,3 +726,5 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT * FROM composites WHERE a = 0")[0] result = s.execute("SELECT * FROM composites WHERE a = 0")[0]
self.assertEqual(0, result.a) self.assertEqual(0, result.a)
self.assertEqual(('abc',), result.b) self.assertEqual(('abc',), result.b)
c.shutdown()