Refactored TypeTests

This commit is contained in:
Kishan Karunaratne
2015-04-07 13:27:16 -07:00
parent 521ab1a116
commit 73ec60f606
2 changed files with 356 additions and 378 deletions

View File

@@ -11,22 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from decimal import Decimal
import datetime
from uuid import UUID
import pytz
from datetime import datetime, date, time
from uuid import uuid1, uuid4
try:
from blist import sortedset
except ImportError:
sortedset = set # noqa
DATA_TYPE_PRIMITIVES = [
from cassandra.util import OrderedMap
from tests.integration import get_server_versions
PRIMITIVE_DATATYPES = [
'ascii',
'bigint',
'blob',
'boolean',
# 'counter', counters are not allowed inside tuples
'decimal',
'double',
'float',
@@ -40,22 +44,28 @@ DATA_TYPE_PRIMITIVES = [
'varint',
]
DATA_TYPE_NON_PRIMITIVE_NAMES = [
COLLECTION_TYPES = [
'list',
'set',
'map',
'tuple'
]
def get_sample_data():
"""
Create a standard set of sample inputs for testing.
"""
def update_datatypes():
    """
    Extend the module-level datatype lists according to the server version.

    The Cassandra version is taken from get_server_versions(). For 2.1.0 and
    later 'tuple' is added to COLLECTION_TYPES; for 2.1.5 and later 'date'
    and 'time' are added to PRIMITIVE_DATATYPES. Both lists are modified in
    place.

    A type that is already present is not appended again, so calling this
    function more than once in the same process does not create duplicate
    entries.
    """
    _cass_version, _cql_version = get_server_versions()

    if _cass_version >= (2, 1, 0) and 'tuple' not in COLLECTION_TYPES:
        COLLECTION_TYPES.append('tuple')

    if _cass_version >= (2, 1, 5):
        for datatype in ('date', 'time'):
            if datatype not in PRIMITIVE_DATATYPES:
                PRIMITIVE_DATATYPES.append(datatype)
def get_sample_data():
sample_data = {}
for datatype in DATA_TYPE_PRIMITIVES:
for datatype in PRIMITIVE_DATATYPES:
if datatype == 'ascii':
sample_data[datatype] = 'ascii'
@@ -68,10 +78,6 @@ def get_sample_data():
elif datatype == 'boolean':
sample_data[datatype] = True
elif datatype == 'counter':
# Not supported in an insert statement
pass
elif datatype == 'decimal':
sample_data[datatype] = Decimal('12.3E+7')
@@ -91,13 +97,13 @@ def get_sample_data():
sample_data[datatype] = 'text'
elif datatype == 'timestamp':
sample_data[datatype] = datetime.datetime.fromtimestamp(872835240, tz=pytz.timezone('America/New_York')).astimezone(pytz.UTC).replace(tzinfo=None)
sample_data[datatype] = datetime(2013, 12, 31, 23, 59, 59, 999000)
elif datatype == 'timeuuid':
sample_data[datatype] = UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66')
sample_data[datatype] = uuid1()
elif datatype == 'uuid':
sample_data[datatype] = UUID('067e6162-3b6f-4ae2-a171-2470b63dff00')
sample_data[datatype] = uuid4()
elif datatype == 'varchar':
sample_data[datatype] = 'varchar'
@@ -105,35 +111,40 @@ def get_sample_data():
elif datatype == 'varint':
sample_data[datatype] = int(str(2147483647) + '000')
elif datatype == 'date':
sample_data[datatype] = date(2015, 1, 15)
elif datatype == 'time':
sample_data[datatype] = time(16, 47, 25, 7)
else:
raise Exception('Missing handling of %s.' % datatype)
raise Exception("Missing handling of {0}".format(datatype))
return sample_data
SAMPLE_DATA = get_sample_data()
def get_sample(datatype):
    """
    Look up the pre-built sample value for a primitive datatype name.

    The value comes from the module-level SAMPLE_DATA mapping; a name that
    is not a key of that mapping raises KeyError.
    """
    sample = SAMPLE_DATA[datatype]
    return sample
def get_collection_sample(collection_type, datatype):
    """
    Helper method to access created sample data for collection types

    Builds a collection of the requested kind whose elements are the
    primitive sample returned by get_sample(datatype):

    - 'list':  a list holding the sample twice
    - 'set':   a sortedset holding the sample
    - 'map':   an OrderedMap with one entry, sample -> sample
    - 'tuple': a one-element tuple holding the sample

    Raises Exception if collection_type is none of the above.
    """
    if collection_type == 'list':
        return [get_sample(datatype), get_sample(datatype)]
    elif collection_type == 'set':
        return sortedset([get_sample(datatype)])
    elif collection_type == 'map':
        # OrderedMap is built from (key, value) pairs, so the same sample can
        # serve as both key and value for every datatype.
        # NOTE(review): presumably this avoids a special case for unhashable
        # samples such as the blob value - confirm.
        return OrderedMap([(get_sample(datatype), get_sample(datatype))])
    elif collection_type == 'tuple':
        return (get_sample(datatype),)
    else:
        raise Exception('Missing handling of non-primitive type {0}.'.format(collection_type))

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.integration.datatype_utils import get_sample, DATA_TYPE_PRIMITIVES, DATA_TYPE_NON_PRIMITIVE_NAMES
try:
import unittest2 as unittest
@@ -21,85 +20,53 @@ except ImportError:
import logging
log = logging.getLogger(__name__)
from collections import namedtuple
from decimal import Decimal
from datetime import datetime, date, time
from functools import partial
from datetime import datetime
import six
from uuid import uuid1, uuid4
from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.query import dict_factory
from cassandra.util import OrderedMap, sortedset
from cassandra.query import dict_factory, ordered_dict_factory
from cassandra.util import sortedset
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION
# defined in module scope for pickling in OrderedMap
nested_collection_udt = namedtuple('nested_collection_udt', ['m', 't', 'l', 's'])
nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u'])
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \
get_sample, get_collection_sample
def setup_module():
    # Module-level setup: calls use_singledc(), then update_datatypes() so the
    # datatype lists imported from datatype_utils are extended before use.
    # NOTE(review): presumably invoked once by the test runner before the
    # tests in this module - confirm against the runner in use.
    use_singledc()
    update_datatypes()
class TypeTests(unittest.TestCase):
_types_table_created = False
def setUp(self):
self._cass_version, self._cql_version = get_server_versions()
@classmethod
def setup_class(cls):
cls._cass_version, cls._cql_version = get_server_versions()
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.session = self.cluster.connect()
self.session.execute("CREATE KEYSPACE typetests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
self.cluster.shutdown()
cls._col_types = ['text',
'ascii',
'bigint',
'boolean',
'decimal',
'double',
'float',
'inet',
'int',
'list<text>',
'set<int>',
'map<text,int>',
'timestamp',
'uuid',
'timeuuid',
'varchar',
'varint']
def tearDown(self):
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.session = self.cluster.connect()
self.session.execute("DROP KEYSPACE typetests")
self.cluster.shutdown()
if cls._cass_version >= (2, 1, 4):
cls._col_types.extend(('date', 'time'))
def test_can_insert_blob_type_as_string(self):
"""
Tests that blob type in Cassandra does not map to string in Python
"""
cls._session = Cluster(protocol_version=PROTOCOL_VERSION).connect()
cls._session.execute("CREATE KEYSPACE typetests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
cls._session.set_keyspace("typetests")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
@classmethod
def teardown_class(cls):
cls._session.execute("DROP KEYSPACE typetests")
cls._session.cluster.shutdown()
s.execute("CREATE TABLE blobstring (a ascii PRIMARY KEY, b blob)")
def test_blob_type_as_string(self):
s = self._session
s.execute("""
CREATE TABLE blobstring (
a ascii,
b blob,
PRIMARY KEY (a)
)
""")
params = [
'key1',
b'blobyblob'
]
query = 'INSERT INTO blobstring (a, b) VALUES (%s, %s)'
params = ['key1', b'blobyblob']
query = "INSERT INTO blobstring (a, b) VALUES (%s, %s)"
# In python 3, the 'bytes' type is treated as a blob, so we can
# correctly encode it with hex notation.
@@ -118,243 +85,277 @@ class TypeTests(unittest.TestCase):
params[1] = params[1].encode('hex')
s.execute(query, params)
expected_vals = [
'key1',
bytearray(b'blobyblob')
]
results = s.execute("SELECT * FROM blobstring")
for expected, actual in zip(expected_vals, results[0]):
results = s.execute("SELECT * FROM blobstring")[0]
for expected, actual in zip(params, results):
self.assertEqual(expected, actual)
def test_blob_type_as_bytearray(self):
s = self._session
s.execute("""
CREATE TABLE blobbytes (
a ascii,
b blob,
PRIMARY KEY (a)
)
""")
c.shutdown()
params = [
'key1',
bytearray(b'blob1')
]
def test_can_insert_blob_type_as_bytearray(self):
"""
Tests that blob type in Cassandra maps to bytearray in Python
"""
query = 'INSERT INTO blobbytes (a, b) VALUES (%s, %s);'
s.execute(query, params)
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
expected_vals = [
'key1',
bytearray(b'blob1')
]
s.execute("CREATE TABLE blobbytes (a ascii PRIMARY KEY, b blob)")
results = s.execute("SELECT * FROM blobbytes")
params = ['key1', bytearray(b'blob1')]
s.execute("INSERT INTO blobbytes (a, b) VALUES (%s, %s)", params)
for expected, actual in zip(expected_vals, results[0]):
results = s.execute("SELECT * FROM blobbytes")[0]
for expected, actual in zip(params, results):
self.assertEqual(expected, actual)
def _create_all_types_table(self):
if not self._types_table_created:
TypeTests._col_names = ["%s_col" % col_type.translate(None, '<> ,') for col_type in self._col_types]
cql = "CREATE TABLE alltypes ( key int PRIMARY KEY, %s)" % ','.join("%s %s" % name_type for name_type in zip(self._col_names, self._col_types))
self._session.execute(cql)
TypeTests._types_table_created = True
c.shutdown()
def test_can_insert_primitive_datatypes(self):
    """
    Test insertion of all datatype primitives

    Creates a table with an int key column "zz" and one single-letter column
    per entry of PRIMITIVE_DATATYPES, writes the sample value for every
    datatype once with a simple statement (row zz=0) and once with a
    prepared statement (row zz=1), and checks that each row reads back equal
    to what was written.
    """
    c = Cluster(protocol_version=PROTOCOL_VERSION)
    try:
        s = c.connect("typetests")

        # create table
        start_index = ord('a')
        letters = [chr(start_index + i) for i in range(len(PRIMITIVE_DATATYPES))]
        col_names = ["zz"] + letters
        alpha_type_list = ["zz int PRIMARY KEY"]
        alpha_type_list.extend("{0} {1}".format(letter, datatype)
                               for letter, datatype in zip(letters, PRIMITIVE_DATATYPES))
        s.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list)))

        # create the input: one sample per datatype, in column order
        samples = [get_sample(datatype) for datatype in PRIMITIVE_DATATYPES]
        columns_string = ', '.join(col_names)

        # insert into table as a simple statement
        params = [0] + samples
        placeholders = ', '.join(["%s"] * len(col_names))
        s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params)

        # verify data
        results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0]
        for expected, actual in zip(params, results):
            self.assertEqual(actual, expected)

        # try the same thing with a prepared statement; a separate key is used
        # so the checks below cannot be satisfied by the row written above
        prepared_params = [1] + samples
        placeholders = ','.join(["?"] * len(col_names))
        insert = s.prepare("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders))
        s.execute(insert.bind(prepared_params))

        # verify data
        results = s.execute("SELECT {0} FROM alltypes WHERE zz=1".format(columns_string))[0]
        for expected, actual in zip(prepared_params, results):
            self.assertEqual(actual, expected)

        # verify data with prepared statement query
        select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string))
        results = s.execute(select.bind([1]))[0]
        for expected, actual in zip(prepared_params, results):
            self.assertEqual(actual, expected)

        # verify data with prepared statement, use dictionary with no explicit columns
        # NOTE(review): like the original, this relies on SELECT * returning
        # "zz" first followed by the lettered columns in order - confirm.
        s.row_factory = ordered_dict_factory
        select = s.prepare("SELECT * FROM alltypes WHERE zz=?")
        results = s.execute(select.bind([1]))[0]
        for expected, actual in zip(prepared_params, results.values()):
            self.assertEqual(actual, expected)
    finally:
        # shut the cluster down even when an assertion above fails
        c.shutdown()
def test_empty_strings_and_nones(self):
s = self._session.cluster.connect()
s.set_keyspace(self._session.keyspace)
s.row_factory = dict_factory
def test_can_insert_collection_datatypes(self):
"""
Test insertion of all collection types
"""
self._create_all_types_table()
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
# use tuple encoding, to convert native python tuple into raw CQL
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
columns_string = ','.join(self._col_names)
s.execute("INSERT INTO alltypes (key) VALUES (2)")
results = s.execute("SELECT %s FROM alltypes WHERE key=2" % columns_string)
self.assertTrue(all(x is None for x in results[0].values()))
# create table
alpha_type_list = ["zz int PRIMARY KEY"]
col_names = ["zz"]
start_index = ord('a')
for i, collection_type in enumerate(COLLECTION_TYPES):
for j, datatype in enumerate(PRIMITIVE_DATATYPES):
if collection_type == "map":
type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j),
collection_type, datatype)
elif collection_type == "tuple":
type_string = "{0}_{1} frozen<{2}<{3}>>".format(chr(start_index + i), chr(start_index + j),
collection_type, datatype)
else:
type_string = "{0}_{1} {2}<{3}>".format(chr(start_index + i), chr(start_index + j),
collection_type, datatype)
alpha_type_list.append(type_string)
col_names.append("{0}_{1}".format(chr(start_index + i), chr(start_index + j)))
prepared = s.prepare("SELECT %s FROM alltypes WHERE key=?" % columns_string)
results = s.execute(prepared.bind((2,)))
self.assertTrue(all(x is None for x in results[0].values()))
s.execute("CREATE TABLE allcoltypes ({0})".format(', '.join(alpha_type_list)))
columns_string = ', '.join(col_names)
# insert empty strings for string-like fields and fetch them
expected_values = {'text_col': '', 'ascii_col': '', 'varchar_col': '', 'listtext_col': [''], 'maptextint_col': OrderedMap({'': 3})}
# create the input for simple statement
params = [0]
for collection_type in COLLECTION_TYPES:
for datatype in PRIMITIVE_DATATYPES:
params.append((get_collection_sample(collection_type, datatype)))
# insert into table as a simple statement
placeholders = ', '.join(["%s"] * len(col_names))
s.execute("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params)
# verify data
results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string))[0]
for expected, actual in zip(params, results):
self.assertEqual(actual, expected)
# create the input for prepared statement
params = [0]
for collection_type in COLLECTION_TYPES:
for datatype in PRIMITIVE_DATATYPES:
params.append((get_collection_sample(collection_type, datatype)))
# try the same thing with a prepared statement
placeholders = ','.join(["?"] * len(col_names))
insert = s.prepare("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders))
s.execute(insert.bind(params))
# verify data
results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string))[0]
for expected, actual in zip(params, results):
self.assertEqual(actual, expected)
# verify data with prepared statement query
select = s.prepare("SELECT {0} FROM allcoltypes WHERE zz=?".format(columns_string))
results = s.execute(select.bind([0]))[0]
for expected, actual in zip(params, results):
self.assertEqual(actual, expected)
# verify data with with prepared statement, use dictionary with no explicit columns
s.row_factory = ordered_dict_factory
select = s.prepare("SELECT * FROM allcoltypes")
results = s.execute(select)[0]
for expected, actual in zip(params, results.values()):
self.assertEqual(actual, expected)
c.shutdown()
def test_can_insert_empty_strings_and_nulls(self):
"""
Test insertion of empty strings and null values
"""
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
# create table
alpha_type_list = ["zz int PRIMARY KEY"]
col_names = []
start_index = ord('a')
for i, datatype in enumerate(PRIMITIVE_DATATYPES):
alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype))
col_names.append(chr(start_index + i))
s.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list)))
# verify all types initially null with simple statement
columns_string = ','.join(col_names)
s.execute("INSERT INTO alltypes (zz) VALUES (2)")
results = s.execute("SELECT {0} FROM alltypes WHERE zz=2".format(columns_string))[0]
self.assertTrue(all(x is None for x in results))
# verify all types initially null with prepared statement
select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string))
results = s.execute(select.bind([2]))[0]
self.assertTrue(all(x is None for x in results))
# insert empty strings for string-like fields
expected_values = {'j': '', 'a': '', 'n': ''}
columns_string = ','.join(expected_values.keys())
placeholders = ','.join(["%s"] * len(expected_values))
s.execute("INSERT INTO alltypes (key, %s) VALUES (3, %s)" % (columns_string, placeholders), expected_values.values())
self.assertEqual(expected_values,
s.execute("SELECT %s FROM alltypes WHERE key=3" % columns_string)[0])
self.assertEqual(expected_values,
s.execute(s.prepare("SELECT %s FROM alltypes WHERE key=?" % columns_string), (3,))[0])
s.execute("INSERT INTO alltypes (zz, {0}) VALUES (3, {1})".format(columns_string, placeholders), expected_values.values())
# verify string types empty with simple statement
results = s.execute("SELECT {0} FROM alltypes WHERE zz=3".format(columns_string))[0]
for expected, actual in zip(expected_values.values(), results):
self.assertEqual(actual, expected)
# verify string types empty with prepared statement
results = s.execute(s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string)), [3])[0]
for expected, actual in zip(expected_values.values(), results):
self.assertEqual(actual, expected)
# non-string types shouldn't accept empty strings
for col in ('bigint_col', 'boolean_col', 'decimal_col', 'double_col',
'float_col', 'int_col', 'listtext_col', 'setint_col',
'maptextint_col', 'uuid_col', 'timeuuid_col', 'varint_col'):
query = "INSERT INTO alltypes (key, %s) VALUES (4, %%s)" % col
try:
for col in ('b', 'd', 'e', 'f', 'g', 'i', 'l', 'm', 'o'):
query = "INSERT INTO alltypes (zz, {0}) VALUES (4, %s)".format(col)
with self.assertRaises(InvalidRequest):
s.execute(query, [''])
except InvalidRequest:
pass
else:
self.fail("Expected an InvalidRequest error when inserting an "
"emptry string for column %s" % (col, ))
prepared = s.prepare("INSERT INTO alltypes (key, %s) VALUES (4, ?)" % col)
try:
s.execute(prepared, [''])
except TypeError:
pass
else:
self.fail("Expected an InvalidRequest error when inserting an "
"emptry string for column %s with a prepared statement" % (col, ))
insert = s.prepare("INSERT INTO alltypes (zz, {0}) VALUES (4, ?)".format(col))
with self.assertRaises(TypeError):
s.execute(insert, [''])
# insert values for all columns
values = ['text', 'ascii', 1, True, Decimal('1.0'), 0.1, 0.1,
"1.2.3.4", 1, ['a'], set([1]), {'a': 1},
datetime.now(), uuid4(), uuid1(), 'a', 1]
if self._cass_version >= (2, 1, 4):
values.append('2014-01-01')
values.append('01:02:03.456789012')
# verify that Nones can be inserted and overwrites existing data
# create the input
params = []
for datatype in PRIMITIVE_DATATYPES:
params.append((get_sample(datatype)))
columns_string = ','.join(self._col_names)
placeholders = ','.join(["%s"] * len(self._col_names))
insert = "INSERT INTO alltypes (key, %s) VALUES (5, %s)" % (columns_string, placeholders)
s.execute(insert, values)
# insert the data
columns_string = ','.join(col_names)
placeholders = ','.join(["%s"] * len(col_names))
simple_insert = "INSERT INTO alltypes (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders)
s.execute(simple_insert, params)
# then insert None, which should null them out
null_values = [None] * len(self._col_names)
s.execute(insert, null_values)
null_values = [None] * len(col_names)
s.execute(simple_insert, null_values)
select = "SELECT %s FROM alltypes WHERE key=5" % columns_string
results = s.execute(select)
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])
# check via simple statement
query = "SELECT {0} FROM alltypes WHERE zz=5".format(columns_string)
results = s.execute(query)[0]
for col in results:
self.assertEqual(None, col)
prepared = s.prepare(select)
results = s.execute(prepared.bind(()))
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])
# check via prepared statement
select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string))
results = s.execute(select.bind([5]))[0]
for col in results:
self.assertEqual(None, col)
# do the same thing again, but use a prepared statement to insert the nulls
s.execute(insert, values)
s.execute(simple_insert, params)
placeholders = ','.join(["?"] * len(self._col_names))
prepared = s.prepare("INSERT INTO alltypes (key, %s) VALUES (5, %s)" % (columns_string, placeholders))
s.execute(prepared, null_values)
placeholders = ','.join(["?"] * len(col_names))
insert = s.prepare("INSERT INTO alltypes (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders))
s.execute(insert, null_values)
results = s.execute(select)
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])
results = s.execute(query)[0]
for col in results:
self.assertEqual(None, col)
prepared = s.prepare(select)
results = s.execute(prepared.bind(()))
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])
results = s.execute(select.bind([5]))[0]
for col in results:
self.assertEqual(None, col)
s.shutdown()
def test_empty_values(self):
s = self._session
def test_can_insert_empty_values_for_int32(self):
"""
Ensure Int32Type supports empty values
"""
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
s.execute("CREATE TABLE empty_values (a text PRIMARY KEY, b int)")
s.execute("INSERT INTO empty_values (a, b) VALUES ('a', blobAsInt(0x))")
try:
@@ -364,8 +365,13 @@ class TypeTests(unittest.TestCase):
finally:
Int32Type.support_empty_values = False
def test_timezone_aware_datetimes(self):
""" Ensure timezone-aware datetimes are converted to timestamps correctly """
c.shutdown()
def test_timezone_aware_datetimes_are_timestamps(self):
"""
Ensure timezone-aware datetimes are converted to timestamps correctly
"""
try:
import pytz
except ImportError as exc:
@@ -375,22 +381,25 @@ class TypeTests(unittest.TestCase):
eastern_tz = pytz.timezone('US/Eastern')
eastern_tz.localize(dt)
s = self._session
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
s.execute("CREATE TABLE tz_aware (a ascii PRIMARY KEY, b timestamp)")
# test non-prepared statement
s.execute("INSERT INTO tz_aware (a, b) VALUES ('key1', %s)", parameters=(dt,))
s.execute("INSERT INTO tz_aware (a, b) VALUES ('key1', %s)", [dt])
result = s.execute("SELECT b FROM tz_aware WHERE a='key1'")[0].b
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
# test prepared statement
prepared = s.prepare("INSERT INTO tz_aware (a, b) VALUES ('key2', ?)")
s.execute(prepared, parameters=(dt,))
insert = s.prepare("INSERT INTO tz_aware (a, b) VALUES ('key2', ?)")
s.execute(insert.bind([dt]))
result = s.execute("SELECT b FROM tz_aware WHERE a='key2'")[0].b
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
def test_tuple_type(self):
c.shutdown()
def test_can_insert_tuples(self):
"""
Basic test of tuple functionality
"""
@@ -398,8 +407,8 @@ class TypeTests(unittest.TestCase):
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session.cluster.connect()
s.set_keyspace(self._session.keyspace)
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
# use this encoder in order to insert tuples
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
@@ -439,9 +448,9 @@ class TypeTests(unittest.TestCase):
self.assertEqual(partial_result, s.execute(prepared, (4,))[0].b)
self.assertEqual(subpartial_result, s.execute(prepared, (5,))[0].b)
s.shutdown()
c.shutdown()
def test_tuple_type_varying_lengths(self):
def test_can_insert_tuples_with_varying_lengths(self):
"""
Test tuple types of lengths of 1, 2, 3, and 384 to ensure edge cases work
as expected.
@@ -450,8 +459,8 @@ class TypeTests(unittest.TestCase):
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session.cluster.connect()
s.set_keyspace(self._session.keyspace)
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
# set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples
@@ -479,9 +488,9 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT v_%s FROM tuple_lengths WHERE k=0", (i,))[0]
self.assertEqual(tuple(created_tuple), result['v_%s' % i])
s.shutdown()
c.shutdown()
def test_tuple_primitive_subtypes(self):
def test_can_insert_tuples_all_primitive_datatypes(self):
"""
Ensure tuple subtypes are appropriately handled.
"""
@@ -489,28 +498,28 @@ class TypeTests(unittest.TestCase):
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session.cluster.connect()
s.set_keyspace(self._session.keyspace)
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("CREATE TABLE tuple_primitive ("
"k int PRIMARY KEY, "
"v frozen<tuple<%s>>)" % ','.join(DATA_TYPE_PRIMITIVES))
"v frozen<tuple<%s>>)" % ','.join(PRIMITIVE_DATATYPES))
for i in range(len(DATA_TYPE_PRIMITIVES)):
for i in range(len(PRIMITIVE_DATATYPES)):
# create tuples to be written and ensure they match with the expected response
# responses have trailing None values for every element that has not been written
created_tuple = [get_sample(DATA_TYPE_PRIMITIVES[j]) for j in range(i + 1)]
response_tuple = tuple(created_tuple + [None for j in range(len(DATA_TYPE_PRIMITIVES) - i - 1)])
created_tuple = [get_sample(PRIMITIVE_DATATYPES[j]) for j in range(i + 1)]
response_tuple = tuple(created_tuple + [None for j in range(len(PRIMITIVE_DATATYPES) - i - 1)])
written_tuple = tuple(created_tuple)
s.execute("INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, written_tuple))
result = s.execute("SELECT v FROM tuple_primitive WHERE k=%s", (i,))[0]
self.assertEqual(response_tuple, result.v)
s.shutdown()
c.shutdown()
def test_tuple_non_primitive_subtypes(self):
def test_can_insert_tuples_all_collection_datatypes(self):
"""
Ensure tuple subtypes are appropriately handled for maps, sets, and lists.
"""
@@ -518,8 +527,8 @@ class TypeTests(unittest.TestCase):
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session.cluster.connect()
s.set_keyspace(self._session.keyspace)
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
# set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples
@@ -529,15 +538,15 @@ class TypeTests(unittest.TestCase):
values = []
# create list values
for datatype in DATA_TYPE_PRIMITIVES:
for datatype in PRIMITIVE_DATATYPES:
values.append('v_{} frozen<tuple<list<{}>>>'.format(len(values), datatype))
# create set values
for datatype in DATA_TYPE_PRIMITIVES:
for datatype in PRIMITIVE_DATATYPES:
values.append('v_{} frozen<tuple<set<{}>>>'.format(len(values), datatype))
# create map values
for datatype in DATA_TYPE_PRIMITIVES:
for datatype in PRIMITIVE_DATATYPES:
datatype_1 = datatype_2 = datatype
if datatype == 'blob':
# unhashable type: 'bytearray'
@@ -545,9 +554,9 @@ class TypeTests(unittest.TestCase):
values.append('v_{} frozen<tuple<map<{}, {}>>>'.format(len(values), datatype_1, datatype_2))
# make sure we're testing all non primitive data types in the future
if set(COLLECTION_TYPES) != set(['tuple', 'list', 'map', 'set']):
    # NotImplemented is a comparison sentinel, not an exception class:
    # calling it raises TypeError and hides the message below. The exception
    # intended here is NotImplementedError.
    raise NotImplementedError('Missing datatype not implemented: {}'.format(
        set(COLLECTION_TYPES) - set(['tuple', 'list', 'map', 'set'])
    ))
# create table
@@ -557,7 +566,7 @@ class TypeTests(unittest.TestCase):
i = 0
# test tuple<list<datatype>>
for datatype in DATA_TYPE_PRIMITIVES:
for datatype in PRIMITIVE_DATATYPES:
created_tuple = tuple([[get_sample(datatype)]])
s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple))
@@ -566,7 +575,7 @@ class TypeTests(unittest.TestCase):
i += 1
# test tuple<set<datatype>>
for datatype in DATA_TYPE_PRIMITIVES:
for datatype in PRIMITIVE_DATATYPES:
created_tuple = tuple([sortedset([get_sample(datatype)])])
s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple))
@@ -575,7 +584,7 @@ class TypeTests(unittest.TestCase):
i += 1
# test tuple<map<datatype, datatype>>
for datatype in DATA_TYPE_PRIMITIVES:
for datatype in PRIMITIVE_DATATYPES:
if datatype == 'blob':
# unhashable type: 'bytearray'
created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}])
@@ -587,7 +596,7 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,))[0]
self.assertEqual(created_tuple, result['v_%s' % i])
i += 1
s.shutdown()
c.shutdown()
def nested_tuples_schema_helper(self, depth):
"""
@@ -609,7 +618,7 @@ class TypeTests(unittest.TestCase):
else:
return (self.nested_tuples_creator_helper(depth - 1), )
def test_nested_tuples(self):
def test_can_insert_nested_tuples(self):
"""
Ensure nested are appropriately handled.
"""
@@ -617,8 +626,8 @@ class TypeTests(unittest.TestCase):
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session.cluster.connect()
s.set_keyspace(self._session.keyspace)
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
# set the row_factory to dict_factory for programmatic access
# set the encoder for tuples for the ability to write tuples
@@ -647,16 +656,18 @@ class TypeTests(unittest.TestCase):
# verify tuple was written and read correctly
result = s.execute("SELECT v_%s FROM nested_tuples WHERE k=%s", (i, i))[0]
self.assertEqual(created_tuple, result['v_%s' % i])
s.shutdown()
c.shutdown()
def test_tuples_with_nulls(self):
def test_can_insert_tuples_with_nulls(self):
"""
Test tuples with null and empty string fields.
"""
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
s = self._session
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
s.execute("CREATE TABLE tuples_nulls (k int PRIMARY KEY, t frozen<tuple<text, int, uuid, blob>>)")
@@ -675,75 +686,29 @@ class TypeTests(unittest.TestCase):
self.assertEqual(('', None, None, b''), result[0].t)
self.assertEqual(('', None, None, b''), s.execute(read)[0].t)
def test_unicode_query_string(self):
s = self._session
c.shutdown()
def test_can_insert_unicode_query_string(self):
    """
    Test to ensure unicode strings can be used in a query
    """
    c = Cluster(protocol_version=PROTOCOL_VERSION)
    try:
        s = c.connect("typetests")

        query = u"SELECT * FROM system.schema_columnfamilies WHERE keyspace_name = 'ef\u2052ef' AND columnfamily_name = %s"
        s.execute(query, (u"fe\u2051fe",))
    finally:
        # shut the cluster down even when connecting or the query raises
        c.shutdown()

def insert_select_column(self, session, table_name, column_name, value):
    """
    Write `value` into `column_name` of row k=0 in `table_name` with a
    prepared INSERT, read it back with a simple SELECT, and assert that the
    value read equals the value written.
    """
    insert = session.prepare("INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name))
    session.execute(insert, (0, value))
    result = session.execute("SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,))[0][0]
    self.assertEqual(result, value)
def test_nested_collections(self):
def test_can_read_composite_type(self):
"""
Test to ensure that CompositeTypes can be used in a query
"""
if self._cass_version < (2, 1, 3):
raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect("typetests")
if PROTOCOL_VERSION < 3:
raise unittest.SkipTest("Protocol version > 3 required for nested collections")
name = self._testMethodName
s = self._session.cluster.connect()
s.set_keyspace(self._session.keyspace)
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""
CREATE TYPE %s (
m frozen<map<int,text>>,
t tuple<int,text>,
l frozen<list<int>>,
s frozen<set<int>>
)""" % name)
s.execute("""
CREATE TYPE %s_nested (
m frozen<map<int,text>>,
t tuple<int,text>,
l frozen<list<int>>,
s frozen<set<int>>,
u frozen<%s>
)""" % (name, name))
s.execute("""
CREATE TABLE %s (
k int PRIMARY KEY,
map_map map<frozen<map<int,int>>, frozen<map<int,int>>>,
map_set map<frozen<set<int>>, frozen<set<int>>>,
map_list map<frozen<list<int>>, frozen<list<int>>>,
map_tuple map<frozen<tuple<int, int>>, frozen<tuple<int>>>,
map_udt map<frozen<%s_nested>, frozen<%s>>,
)""" % (name, name, name))
validate = partial(self.insert_select_column, s, name)
validate('map_map', OrderedMap([({1: 1, 2: 2}, {3: 3, 4: 4}), ({5: 5, 6: 6}, {7: 7, 8: 8})]))
validate('map_set', OrderedMap([(set((1, 2)), set((3, 4))), (set((5, 6)), set((7, 8)))]))
validate('map_list', OrderedMap([([1, 2], [3, 4]), ([5, 6], [7, 8])]))
validate('map_tuple', OrderedMap([((1, 2), (3,)), ((4, 5), (6,))]))
value = nested_collection_udt({1: 'v1', 2: 'v2'}, (3, 'v3'), [4, 5, 6, 7], set((8, 9, 10)))
key = nested_collection_udt_nested(value.m, value.t, value.l, value.s, value)
key2 = nested_collection_udt_nested({3: 'v3'}, value.t, value.l, value.s, value)
validate('map_udt', OrderedMap([(key, value), (key2, value)]))
s.execute("DROP TABLE %s" % (name))
s.execute("DROP TYPE %s_nested" % (name))
s.execute("DROP TYPE %s" % (name))
s.shutdown()
def test_reading_composite_type(self):
s = self._session
s.execute("""
CREATE TABLE composites (
a int PRIMARY KEY,
@@ -761,3 +726,5 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT * FROM composites WHERE a = 0")[0]
self.assertEqual(0, result.a)
self.assertEqual(('abc',), result.b)
c.shutdown()