diff --git a/tests/integration/datatype_utils.py b/tests/integration/datatype_utils.py index 41a4c09e..bd76c36f 100644 --- a/tests/integration/datatype_utils.py +++ b/tests/integration/datatype_utils.py @@ -11,22 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from decimal import Decimal -import datetime -from uuid import UUID -import pytz +from datetime import datetime, date, time +from uuid import uuid1, uuid4 try: from blist import sortedset except ImportError: sortedset = set # noqa -DATA_TYPE_PRIMITIVES = [ +from cassandra.util import OrderedMap + +from tests.integration import get_server_versions + + +PRIMITIVE_DATATYPES = [ 'ascii', 'bigint', 'blob', 'boolean', - # 'counter', counters are not allowed inside tuples 'decimal', 'double', 'float', @@ -40,22 +44,28 @@ DATA_TYPE_PRIMITIVES = [ 'varint', ] -DATA_TYPE_NON_PRIMITIVE_NAMES = [ +COLLECTION_TYPES = [ 'list', 'set', 'map', - 'tuple' ] -def get_sample_data(): - """ - Create a standard set of sample inputs for testing. - """ +def update_datatypes(): + _cass_version, _cql_version = get_server_versions() + if _cass_version >= (2, 1, 0): + COLLECTION_TYPES.append('tuple') + + if _cass_version >= (2, 1, 5): + PRIMITIVE_DATATYPES.append('date') + PRIMITIVE_DATATYPES.append('time') + + +def get_sample_data(): sample_data = {} - for datatype in DATA_TYPE_PRIMITIVES: + for datatype in PRIMITIVE_DATATYPES: if datatype == 'ascii': sample_data[datatype] = 'ascii' @@ -68,10 +78,6 @@ def get_sample_data(): elif datatype == 'boolean': sample_data[datatype] = True - elif datatype == 'counter': - # Not supported in an insert statement - pass - elif datatype == 'decimal': sample_data[datatype] = Decimal('12.3E+7') @@ -91,13 +97,13 @@ def get_sample_data(): sample_data[datatype] = 'text' elif datatype == 'timestamp': - sample_data[datatype] = datetime.datetime.fromtimestamp(872835240, tz=pytz.timezone('America/New_York')).astimezone(pytz.UTC).replace(tzinfo=None) + sample_data[datatype] = datetime(2013, 12, 31, 23, 59, 59, 999000) elif datatype == 'timeuuid': - sample_data[datatype] = UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66') + sample_data[datatype] = uuid1() elif datatype == 'uuid': - sample_data[datatype] = UUID('067e6162-3b6f-4ae2-a171-2470b63dff00') + sample_data[datatype] = uuid4() elif datatype == 'varchar': sample_data[datatype] = 'varchar' @@ -105,35 +111,40 @@ def get_sample_data(): elif datatype == 'varint': sample_data[datatype] = int(str(2147483647) + '000') + elif datatype == 'date': + sample_data[datatype] = date(2015, 1, 15) + + elif datatype == 'time': + sample_data[datatype] = time(16, 47, 25, 7) + else: - raise Exception('Missing handling of %s.' % datatype) + raise Exception("Missing handling of {0}".format(datatype)) return sample_data SAMPLE_DATA = get_sample_data() + def get_sample(datatype): """ - Helper method to access created sample data + Helper method to access created sample data for primitive types """ return SAMPLE_DATA[datatype] -def get_nonprim_sample(non_prim_type, datatype): + +def get_collection_sample(collection_type, datatype): """ - Helper method to access created sample data for non-primitives + Helper method to access created sample data for collection types """ - if non_prim_type == 'list': + if collection_type == 'list': return [get_sample(datatype), get_sample(datatype)] - elif non_prim_type == 'set': + elif collection_type == 'set': return sortedset([get_sample(datatype)]) - elif non_prim_type == 'map': - if datatype == 'blob': - return {get_sample('ascii'): get_sample(datatype)} - else: - return {get_sample(datatype): get_sample(datatype)} - elif non_prim_type == 'tuple': + elif collection_type == 'map': + return OrderedMap([(get_sample(datatype), get_sample(datatype))]) + elif collection_type == 'tuple': return (get_sample(datatype),) else: - raise Exception('Missing handling of non-primitive type {0}.'.format(non_prim_type)) + raise Exception('Missing handling of non-primitive type {0}.'.format(collection_type)) diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index 2191db1f..b256e875 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from tests.integration.datatype_utils import get_sample, DATA_TYPE_PRIMITIVES, DATA_TYPE_NON_PRIMITIVE_NAMES try: import unittest2 as unittest @@ -21,85 +20,53 @@ except ImportError: import logging log = logging.getLogger(__name__) -from collections import namedtuple -from decimal import Decimal -from datetime import datetime, date, time -from functools import partial +from datetime import datetime import six -from uuid import uuid1, uuid4 from cassandra import InvalidRequest from cassandra.cluster import Cluster from cassandra.cqltypes import Int32Type, EMPTY -from cassandra.query import dict_factory -from cassandra.util import OrderedMap, sortedset +from cassandra.query import dict_factory, ordered_dict_factory +from cassandra.util import sortedset from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION - -# defined in module scope for pickling in OrderedMap -nested_collection_udt = namedtuple('nested_collection_udt', ['m', 't', 'l', 's']) -nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u']) +from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \ + get_sample, get_collection_sample def setup_module(): use_singledc() + update_datatypes() class TypeTests(unittest.TestCase): - _types_table_created = False + def setUp(self): + self._cass_version, self._cql_version = get_server_versions() - @classmethod - def setup_class(cls): - cls._cass_version, cls._cql_version = get_server_versions() + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.session = self.cluster.connect() + self.session.execute("CREATE KEYSPACE typetests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}") + self.cluster.shutdown() - cls._col_types = ['text', - 'ascii', - 'bigint', - 'boolean', - 'decimal', - 'double', - 'float', - 'inet', - 'int', - 'list', - 'set', - 'map', - 'timestamp', - 'uuid', - 'timeuuid', - 'varchar', - 'varint'] + def tearDown(self): + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.session = self.cluster.connect() + self.session.execute("DROP KEYSPACE typetests") + self.cluster.shutdown() - if cls._cass_version >= (2, 1, 4): - cls._col_types.extend(('date', 'time')) + def test_can_insert_blob_type_as_string(self): + """ + Tests that blob type in Cassandra does not map to string in Python + """ - cls._session = Cluster(protocol_version=PROTOCOL_VERSION).connect() - cls._session.execute("CREATE KEYSPACE typetests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}") - cls._session.set_keyspace("typetests") + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") - @classmethod - def teardown_class(cls): - cls._session.execute("DROP KEYSPACE typetests") - cls._session.cluster.shutdown() + s.execute("CREATE TABLE blobstring (a ascii PRIMARY KEY, b blob)") - def test_blob_type_as_string(self): - s = self._session - - s.execute(""" - CREATE TABLE blobstring ( - a ascii, - b blob, - PRIMARY KEY (a) - ) - """) - - params = [ - 'key1', - b'blobyblob' - ] - - query = 'INSERT INTO blobstring (a, b) VALUES (%s, %s)' + params = ['key1', b'blobyblob'] + query = "INSERT INTO blobstring (a, b) VALUES (%s, %s)" # In python 3, the 'bytes' type is treated as a blob, so we can # correctly encode it with hex notation. @@ -118,243 +85,277 @@ class TypeTests(unittest.TestCase): params[1] = params[1].encode('hex') s.execute(query, params) - expected_vals = [ - 'key1', - bytearray(b'blobyblob') - ] - results = s.execute("SELECT * FROM blobstring") - - for expected, actual in zip(expected_vals, results[0]): + results = s.execute("SELECT * FROM blobstring")[0] + for expected, actual in zip(params, results): self.assertEqual(expected, actual) - def test_blob_type_as_bytearray(self): - s = self._session - s.execute(""" - CREATE TABLE blobbytes ( - a ascii, - b blob, - PRIMARY KEY (a) - ) - """) + c.shutdown() - params = [ - 'key1', - bytearray(b'blob1') - ] + def test_can_insert_blob_type_as_bytearray(self): + """ + Tests that blob type in Cassandra maps to bytearray in Python + """ - query = 'INSERT INTO blobbytes (a, b) VALUES (%s, %s);' - s.execute(query, params) + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") - expected_vals = [ - 'key1', - bytearray(b'blob1') - ] + s.execute("CREATE TABLE blobbytes (a ascii PRIMARY KEY, b blob)") - results = s.execute("SELECT * FROM blobbytes") + params = ['key1', bytearray(b'blob1')] + s.execute("INSERT INTO blobbytes (a, b) VALUES (%s, %s)", params) - for expected, actual in zip(expected_vals, results[0]): + results = s.execute("SELECT * FROM blobbytes")[0] + for expected, actual in zip(params, results): self.assertEqual(expected, actual) - def _create_all_types_table(self): - if not self._types_table_created: - TypeTests._col_names = ["%s_col" % col_type.translate(None, '<> ,') for col_type in self._col_types] - cql = "CREATE TABLE alltypes ( key int PRIMARY KEY, %s)" % ','.join("%s %s" % name_type for name_type in zip(self._col_names, self._col_types)) - self._session.execute(cql) - TypeTests._types_table_created = True + c.shutdown() - def test_basic_types(self): + def test_can_insert_primitive_datatypes(self): + """ + Test insertion of all datatype primitives + """ - s = self._session.cluster.connect() - s.set_keyspace(self._session.keyspace) + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") - self._create_all_types_table() + # create table + alpha_type_list = ["zz int PRIMARY KEY"] + col_names = ["zz"] + start_index = ord('a') + for i, datatype in enumerate(PRIMITIVE_DATATYPES): + alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype)) + col_names.append(chr(start_index + i)) - v1_uuid = uuid1() - v4_uuid = uuid4() - mydatetime = datetime(2013, 12, 31, 23, 59, 59, 999000) + s.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list))) - # this could use some rework tying column types to names (instead of relying on position) - params = [ - "text", - "ascii", - 12345678923456789, # bigint - True, # boolean - Decimal('1.234567890123456789'), # decimal - 0.000244140625, # double - 1.25, # float - "1.2.3.4", # inet - 12345, # int - ['a', 'b', 'c'], # list collection - set([1, 2, 3]), # set collection - {'a': 1, 'b': 2}, # map collection - mydatetime, # timestamp - v4_uuid, # uuid - v1_uuid, # timeuuid - u"sometext\u1234", # varchar - 123456789123456789123456789, # varint - ] + # create the input + params = [0] + for datatype in PRIMITIVE_DATATYPES: + params.append((get_sample(datatype))) - expected_vals = [ - "text", - "ascii", - 12345678923456789, # bigint - True, # boolean - Decimal('1.234567890123456789'), # decimal - 0.000244140625, # double - 1.25, # float - "1.2.3.4", # inet - 12345, # int - ['a', 'b', 'c'], # list collection - sortedset((1, 2, 3)), # set collection - {'a': 1, 'b': 2}, # map collection - mydatetime, # timestamp - v4_uuid, # uuid - v1_uuid, # timeuuid - u"sometext\u1234", # varchar - 123456789123456789123456789, # varint - ] + # insert into table as a simple statement + columns_string = ', '.join(col_names) + placeholders = ', '.join(["%s"] * len(col_names)) + s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params) - if self._cass_version >= (2, 1, 4): - mydate = date(2015, 1, 15) - mytime = time(16, 47, 25, 7) - - params.append(mydate) - params.append(mytime) - - expected_vals.append(mydate) - expected_vals.append(mytime) - - columns_string = ','.join(self._col_names) - placeholders = ', '.join(["%s"] * len(self._col_names)) - s.execute("INSERT INTO alltypes (key, %s) VALUES (0, %s)" % - (columns_string, placeholders), params) - - results = s.execute("SELECT %s FROM alltypes WHERE key=0" % columns_string) - - for expected, actual in zip(expected_vals, results[0]): + # verify data + results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0] + for expected, actual in zip(params, results): self.assertEqual(actual, expected) # try the same thing with a prepared statement - placeholders = ','.join(["?"] * len(self._col_names)) - prepared = s.prepare("INSERT INTO alltypes (key, %s) VALUES (1, %s)" % - (columns_string, placeholders)) - s.execute(prepared.bind(params)) + placeholders = ','.join(["?"] * len(col_names)) + insert = s.prepare("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders)) + s.execute(insert.bind(params)) - results = s.execute("SELECT %s FROM alltypes WHERE key=1" % columns_string) - - for expected, actual in zip(expected_vals, results[0]): + # verify data + results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0] + for expected, actual in zip(params, results): self.assertEqual(actual, expected) - # query with prepared statement - prepared = s.prepare("SELECT %s FROM alltypes WHERE key=?" % columns_string) - results = s.execute(prepared.bind((1,))) - - for expected, actual in zip(expected_vals, results[0]): + # verify data with prepared statement query + select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string)) + results = s.execute(select.bind([0]))[0] + for expected, actual in zip(params, results): self.assertEqual(actual, expected) - # query with prepared statement, no explicit columns - s.row_factory = dict_factory - prepared = s.prepare("SELECT * FROM alltypes") - results = s.execute(prepared.bind(())) + # verify data with with prepared statement, use dictionary with no explicit columns + s.row_factory = ordered_dict_factory + select = s.prepare("SELECT * FROM alltypes") + results = s.execute(select)[0] - row = results[0] - for expected, name in zip(expected_vals, self._col_names): - self.assertEqual(row[name], expected) + for expected, actual in zip(params, results.values()): + self.assertEqual(actual, expected) - s.shutdown() + c.shutdown() - def test_empty_strings_and_nones(self): - s = self._session.cluster.connect() - s.set_keyspace(self._session.keyspace) - s.row_factory = dict_factory + def test_can_insert_collection_datatypes(self): + """ + Test insertion of all collection types + """ - self._create_all_types_table() + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") + # use tuple encoding, to convert native python tuple into raw CQL + s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple - columns_string = ','.join(self._col_names) - s.execute("INSERT INTO alltypes (key) VALUES (2)") - results = s.execute("SELECT %s FROM alltypes WHERE key=2" % columns_string) - self.assertTrue(all(x is None for x in results[0].values())) + # create table + alpha_type_list = ["zz int PRIMARY KEY"] + col_names = ["zz"] + start_index = ord('a') + for i, collection_type in enumerate(COLLECTION_TYPES): + for j, datatype in enumerate(PRIMITIVE_DATATYPES): + if collection_type == "map": + type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j), + collection_type, datatype) + elif collection_type == "tuple": + type_string = "{0}_{1} frozen<{2}<{3}>>".format(chr(start_index + i), chr(start_index + j), + collection_type, datatype) + else: + type_string = "{0}_{1} {2}<{3}>".format(chr(start_index + i), chr(start_index + j), + collection_type, datatype) + alpha_type_list.append(type_string) + col_names.append("{0}_{1}".format(chr(start_index + i), chr(start_index + j))) - prepared = s.prepare("SELECT %s FROM alltypes WHERE key=?" % columns_string) - results = s.execute(prepared.bind((2,))) - self.assertTrue(all(x is None for x in results[0].values())) + s.execute("CREATE TABLE allcoltypes ({0})".format(', '.join(alpha_type_list))) + columns_string = ', '.join(col_names) - # insert empty strings for string-like fields and fetch them - expected_values = {'text_col': '', 'ascii_col': '', 'varchar_col': '', 'listtext_col': [''], 'maptextint_col': OrderedMap({'': 3})} + # create the input for simple statement + params = [0] + for collection_type in COLLECTION_TYPES: + for datatype in PRIMITIVE_DATATYPES: + params.append((get_collection_sample(collection_type, datatype))) + + # insert into table as a simple statement + placeholders = ', '.join(["%s"] * len(col_names)) + s.execute("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params) + + # verify data + results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string))[0] + for expected, actual in zip(params, results): + self.assertEqual(actual, expected) + + # create the input for prepared statement + params = [0] + for collection_type in COLLECTION_TYPES: + for datatype in PRIMITIVE_DATATYPES: + params.append((get_collection_sample(collection_type, datatype))) + + # try the same thing with a prepared statement + placeholders = ','.join(["?"] * len(col_names)) + insert = s.prepare("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders)) + s.execute(insert.bind(params)) + + # verify data + results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string))[0] + for expected, actual in zip(params, results): + self.assertEqual(actual, expected) + + # verify data with prepared statement query + select = s.prepare("SELECT {0} FROM allcoltypes WHERE zz=?".format(columns_string)) + results = s.execute(select.bind([0]))[0] + for expected, actual in zip(params, results): + self.assertEqual(actual, expected) + + # verify data with with prepared statement, use dictionary with no explicit columns + s.row_factory = ordered_dict_factory + select = s.prepare("SELECT * FROM allcoltypes") + results = s.execute(select)[0] + + for expected, actual in zip(params, results.values()): + self.assertEqual(actual, expected) + + c.shutdown() + + def test_can_insert_empty_strings_and_nulls(self): + """ + Test insertion of empty strings and null values + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") + + # create table + alpha_type_list = ["zz int PRIMARY KEY"] + col_names = [] + start_index = ord('a') + for i, datatype in enumerate(PRIMITIVE_DATATYPES): + alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype)) + col_names.append(chr(start_index + i)) + + s.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list))) + + # verify all types initially null with simple statement + columns_string = ','.join(col_names) + s.execute("INSERT INTO alltypes (zz) VALUES (2)") + results = s.execute("SELECT {0} FROM alltypes WHERE zz=2".format(columns_string))[0] + self.assertTrue(all(x is None for x in results)) + + # verify all types initially null with prepared statement + select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string)) + results = s.execute(select.bind([2]))[0] + self.assertTrue(all(x is None for x in results)) + + # insert empty strings for string-like fields + expected_values = {'j': '', 'a': '', 'n': ''} columns_string = ','.join(expected_values.keys()) placeholders = ','.join(["%s"] * len(expected_values)) - s.execute("INSERT INTO alltypes (key, %s) VALUES (3, %s)" % (columns_string, placeholders), expected_values.values()) - self.assertEqual(expected_values, - s.execute("SELECT %s FROM alltypes WHERE key=3" % columns_string)[0]) - self.assertEqual(expected_values, - s.execute(s.prepare("SELECT %s FROM alltypes WHERE key=?" % columns_string), (3,))[0]) + s.execute("INSERT INTO alltypes (zz, {0}) VALUES (3, {1})".format(columns_string, placeholders), expected_values.values()) + + # verify string types empty with simple statement + results = s.execute("SELECT {0} FROM alltypes WHERE zz=3".format(columns_string))[0] + for expected, actual in zip(expected_values.values(), results): + self.assertEqual(actual, expected) + + # verify string types empty with prepared statement + results = s.execute(s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string)), [3])[0] + for expected, actual in zip(expected_values.values(), results): + self.assertEqual(actual, expected) # non-string types shouldn't accept empty strings - for col in ('bigint_col', 'boolean_col', 'decimal_col', 'double_col', - 'float_col', 'int_col', 'listtext_col', 'setint_col', - 'maptextint_col', 'uuid_col', 'timeuuid_col', 'varint_col'): - query = "INSERT INTO alltypes (key, %s) VALUES (4, %%s)" % col - try: + for col in ('b', 'd', 'e', 'f', 'g', 'i', 'l', 'm', 'o'): + query = "INSERT INTO alltypes (zz, {0}) VALUES (4, %s)".format(col) + with self.assertRaises(InvalidRequest): s.execute(query, ['']) - except InvalidRequest: - pass - else: - self.fail("Expected an InvalidRequest error when inserting an " - "emptry string for column %s" % (col, )) - prepared = s.prepare("INSERT INTO alltypes (key, %s) VALUES (4, ?)" % col) - try: - s.execute(prepared, ['']) - except TypeError: - pass - else: - self.fail("Expected an InvalidRequest error when inserting an " - "emptry string for column %s with a prepared statement" % (col, )) + insert = s.prepare("INSERT INTO alltypes (zz, {0}) VALUES (4, ?)".format(col)) + with self.assertRaises(TypeError): + s.execute(insert, ['']) - # insert values for all columns - values = ['text', 'ascii', 1, True, Decimal('1.0'), 0.1, 0.1, - "1.2.3.4", 1, ['a'], set([1]), {'a': 1}, - datetime.now(), uuid4(), uuid1(), 'a', 1] - if self._cass_version >= (2, 1, 4): - values.append('2014-01-01') - values.append('01:02:03.456789012') + # verify that Nones can be inserted and overwrites existing data + # create the input + params = [] + for datatype in PRIMITIVE_DATATYPES: + params.append((get_sample(datatype))) - columns_string = ','.join(self._col_names) - placeholders = ','.join(["%s"] * len(self._col_names)) - insert = "INSERT INTO alltypes (key, %s) VALUES (5, %s)" % (columns_string, placeholders) - s.execute(insert, values) + # insert the data + columns_string = ','.join(col_names) + placeholders = ','.join(["%s"] * len(col_names)) + simple_insert = "INSERT INTO alltypes (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders) + s.execute(simple_insert, params) # then insert None, which should null them out - null_values = [None] * len(self._col_names) - s.execute(insert, null_values) + null_values = [None] * len(col_names) + s.execute(simple_insert, null_values) - select = "SELECT %s FROM alltypes WHERE key=5" % columns_string - results = s.execute(select) - self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) + # check via simple statement + query = "SELECT {0} FROM alltypes WHERE zz=5".format(columns_string) + results = s.execute(query)[0] + for col in results: + self.assertEqual(None, col) - prepared = s.prepare(select) - results = s.execute(prepared.bind(())) - self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) + # check via prepared statement + select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string)) + results = s.execute(select.bind([5]))[0] + for col in results: + self.assertEqual(None, col) # do the same thing again, but use a prepared statement to insert the nulls - s.execute(insert, values) + s.execute(simple_insert, params) - placeholders = ','.join(["?"] * len(self._col_names)) - prepared = s.prepare("INSERT INTO alltypes (key, %s) VALUES (5, %s)" % (columns_string, placeholders)) - s.execute(prepared, null_values) + placeholders = ','.join(["?"] * len(col_names)) + insert = s.prepare("INSERT INTO alltypes (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders)) + s.execute(insert, null_values) - results = s.execute(select) - self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) + results = s.execute(query)[0] + for col in results: + self.assertEqual(None, col) - prepared = s.prepare(select) - results = s.execute(prepared.bind(())) - self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) + results = s.execute(select.bind([5]))[0] + for col in results: + self.assertEqual(None, col) s.shutdown() - def test_empty_values(self): - s = self._session + def test_can_insert_empty_values_for_int32(self): + """ + Ensure Int32Type supports empty values + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") + s.execute("CREATE TABLE empty_values (a text PRIMARY KEY, b int)") s.execute("INSERT INTO empty_values (a, b) VALUES ('a', blobAsInt(0x))") try: @@ -364,8 +365,13 @@ class TypeTests(unittest.TestCase): finally: Int32Type.support_empty_values = False - def test_timezone_aware_datetimes(self): - """ Ensure timezone-aware datetimes are converted to timestamps correctly """ + c.shutdown() + + def test_timezone_aware_datetimes_are_timestamps(self): + """ + Ensure timezone-aware datetimes are converted to timestamps correctly + """ + try: import pytz except ImportError as exc: @@ -375,22 +381,25 @@ class TypeTests(unittest.TestCase): eastern_tz = pytz.timezone('US/Eastern') eastern_tz.localize(dt) - s = self._session + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") s.execute("CREATE TABLE tz_aware (a ascii PRIMARY KEY, b timestamp)") # test non-prepared statement - s.execute("INSERT INTO tz_aware (a, b) VALUES ('key1', %s)", parameters=(dt,)) + s.execute("INSERT INTO tz_aware (a, b) VALUES ('key1', %s)", [dt]) result = s.execute("SELECT b FROM tz_aware WHERE a='key1'")[0].b self.assertEqual(dt.utctimetuple(), result.utctimetuple()) # test prepared statement - prepared = s.prepare("INSERT INTO tz_aware (a, b) VALUES ('key2', ?)") - s.execute(prepared, parameters=(dt,)) + insert = s.prepare("INSERT INTO tz_aware (a, b) VALUES ('key2', ?)") + s.execute(insert.bind([dt])) result = s.execute("SELECT b FROM tz_aware WHERE a='key2'")[0].b self.assertEqual(dt.utctimetuple(), result.utctimetuple()) - def test_tuple_type(self): + c.shutdown() + + def test_can_insert_tuples(self): """ Basic test of tuple functionality """ @@ -398,8 +407,8 @@ class TypeTests(unittest.TestCase): if self._cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") - s = self._session.cluster.connect() - s.set_keyspace(self._session.keyspace) + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") # use this encoder in order to insert tuples s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple @@ -439,9 +448,9 @@ class TypeTests(unittest.TestCase): self.assertEqual(partial_result, s.execute(prepared, (4,))[0].b) self.assertEqual(subpartial_result, s.execute(prepared, (5,))[0].b) - s.shutdown() + c.shutdown() - def test_tuple_type_varying_lengths(self): + def test_can_insert_tuples_with_varying_lengths(self): """ Test tuple types of lengths of 1, 2, 3, and 384 to ensure edge cases work as expected. @@ -450,8 +459,8 @@ class TypeTests(unittest.TestCase): if self._cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") - s = self._session.cluster.connect() - s.set_keyspace(self._session.keyspace) + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") # set the row_factory to dict_factory for programmatic access # set the encoder for tuples for the ability to write tuples @@ -479,9 +488,9 @@ class TypeTests(unittest.TestCase): result = s.execute("SELECT v_%s FROM tuple_lengths WHERE k=0", (i,))[0] self.assertEqual(tuple(created_tuple), result['v_%s' % i]) - s.shutdown() + c.shutdown() - def test_tuple_primitive_subtypes(self): + def test_can_insert_tuples_all_primitive_datatypes(self): """ Ensure tuple subtypes are appropriately handled. """ @@ -489,28 +498,28 @@ class TypeTests(unittest.TestCase): if self._cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") - s = self._session.cluster.connect() - s.set_keyspace(self._session.keyspace) + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple s.execute("CREATE TABLE tuple_primitive (" "k int PRIMARY KEY, " - "v frozen>)" % ','.join(DATA_TYPE_PRIMITIVES)) + "v frozen>)" % ','.join(PRIMITIVE_DATATYPES)) - for i in range(len(DATA_TYPE_PRIMITIVES)): + for i in range(len(PRIMITIVE_DATATYPES)): # create tuples to be written and ensure they match with the expected response # responses have trailing None values for every element that has not been written - created_tuple = [get_sample(DATA_TYPE_PRIMITIVES[j]) for j in range(i + 1)] - response_tuple = tuple(created_tuple + [None for j in range(len(DATA_TYPE_PRIMITIVES) - i - 1)]) + created_tuple = [get_sample(PRIMITIVE_DATATYPES[j]) for j in range(i + 1)] + response_tuple = tuple(created_tuple + [None for j in range(len(PRIMITIVE_DATATYPES) - i - 1)]) written_tuple = tuple(created_tuple) s.execute("INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, written_tuple)) result = s.execute("SELECT v FROM tuple_primitive WHERE k=%s", (i,))[0] self.assertEqual(response_tuple, result.v) - s.shutdown() + c.shutdown() - def test_tuple_non_primitive_subtypes(self): + def test_can_insert_tuples_all_collection_datatypes(self): """ Ensure tuple subtypes are appropriately handled for maps, sets, and lists. """ @@ -518,8 +527,8 @@ class TypeTests(unittest.TestCase): if self._cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") - s = self._session.cluster.connect() - s.set_keyspace(self._session.keyspace) + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") # set the row_factory to dict_factory for programmatic access # set the encoder for tuples for the ability to write tuples @@ -529,15 +538,15 @@ class TypeTests(unittest.TestCase): values = [] # create list values - for datatype in DATA_TYPE_PRIMITIVES: + for datatype in PRIMITIVE_DATATYPES: values.append('v_{} frozen>>'.format(len(values), datatype)) # create set values - for datatype in DATA_TYPE_PRIMITIVES: + for datatype in PRIMITIVE_DATATYPES: values.append('v_{} frozen>>'.format(len(values), datatype)) # create map values - for datatype in DATA_TYPE_PRIMITIVES: + for datatype in PRIMITIVE_DATATYPES: datatype_1 = datatype_2 = datatype if datatype == 'blob': # unhashable type: 'bytearray' @@ -545,9 +554,9 @@ class TypeTests(unittest.TestCase): values.append('v_{} frozen>>'.format(len(values), datatype_1, datatype_2)) # make sure we're testing all non primitive data types in the future - if set(DATA_TYPE_NON_PRIMITIVE_NAMES) != set(['tuple', 'list', 'map', 'set']): + if set(COLLECTION_TYPES) != set(['tuple', 'list', 'map', 'set']): raise NotImplemented('Missing datatype not implemented: {}'.format( - set(DATA_TYPE_NON_PRIMITIVE_NAMES) - set(['tuple', 'list', 'map', 'set']) + set(COLLECTION_TYPES) - set(['tuple', 'list', 'map', 'set']) )) # create table @@ -557,7 +566,7 @@ class TypeTests(unittest.TestCase): i = 0 # test tuple> - for datatype in DATA_TYPE_PRIMITIVES: + for datatype in PRIMITIVE_DATATYPES: created_tuple = tuple([[get_sample(datatype)]]) s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) @@ -566,7 +575,7 @@ class TypeTests(unittest.TestCase): i += 1 # test tuple> - for datatype in DATA_TYPE_PRIMITIVES: + for datatype in PRIMITIVE_DATATYPES: created_tuple = tuple([sortedset([get_sample(datatype)])]) s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) @@ -575,7 +584,7 @@ class TypeTests(unittest.TestCase): i += 1 # test tuple> - for datatype in DATA_TYPE_PRIMITIVES: + for datatype in PRIMITIVE_DATATYPES: if datatype == 'blob': # unhashable type: 'bytearray' created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}]) @@ -587,7 +596,7 @@ class TypeTests(unittest.TestCase): result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,))[0] self.assertEqual(created_tuple, result['v_%s' % i]) i += 1 - s.shutdown() + c.shutdown() def nested_tuples_schema_helper(self, depth): """ @@ -609,7 +618,7 @@ class TypeTests(unittest.TestCase): else: return (self.nested_tuples_creator_helper(depth - 1), ) - def test_nested_tuples(self): + def test_can_insert_nested_tuples(self): """ Ensure nested are appropriately handled. """ @@ -617,8 +626,8 @@ class TypeTests(unittest.TestCase): if self._cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") - s = self._session.cluster.connect() - s.set_keyspace(self._session.keyspace) + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") # set the row_factory to dict_factory for programmatic access # set the encoder for tuples for the ability to write tuples @@ -647,16 +656,18 @@ class TypeTests(unittest.TestCase): # verify tuple was written and read correctly result = s.execute("SELECT v_%s FROM nested_tuples WHERE k=%s", (i, i))[0] self.assertEqual(created_tuple, result['v_%s' % i]) - s.shutdown() + c.shutdown() - def test_tuples_with_nulls(self): + def test_can_insert_tuples_with_nulls(self): """ Test tuples with null and empty string fields. """ + if self._cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") - s = self._session + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") s.execute("CREATE TABLE tuples_nulls (k int PRIMARY KEY, t frozen>)") @@ -675,75 +686,29 @@ class TypeTests(unittest.TestCase): self.assertEqual(('', None, None, b''), result[0].t) self.assertEqual(('', None, None, b''), s.execute(read)[0].t) - def test_unicode_query_string(self): - s = self._session + c.shutdown() + + def test_can_insert_unicode_query_string(self): + """ + Test to ensure unicode strings can be used in a query + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") query = u"SELECT * FROM system.schema_columnfamilies WHERE keyspace_name = 'ef\u2052ef' AND columnfamily_name = %s" s.execute(query, (u"fe\u2051fe",)) - def insert_select_column(self, session, table_name, column_name, value): - insert = session.prepare("INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name)) - session.execute(insert, (0, value)) - result = session.execute("SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,))[0][0] - self.assertEqual(result, value) + c.shutdown() - def test_nested_collections(self): + def test_can_read_composite_type(self): + """ + Test to ensure that CompositeTypes can be used in a query + """ - if self._cass_version < (2, 1, 3): - raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3") + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect("typetests") - if PROTOCOL_VERSION < 3: - raise unittest.SkipTest("Protocol version > 3 required for nested collections") - - name = self._testMethodName - - s = self._session.cluster.connect() - s.set_keyspace(self._session.keyspace) - s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple - - s.execute(""" - CREATE TYPE %s ( - m frozen>, - t tuple, - l frozen>, - s frozen> - )""" % name) - s.execute(""" - CREATE TYPE %s_nested ( - m frozen>, - t tuple, - l frozen>, - s frozen>, - u frozen<%s> - )""" % (name, name)) - s.execute(""" - CREATE TABLE %s ( - k int PRIMARY KEY, - map_map map>, frozen>>, - map_set map>, frozen>>, - map_list map>, frozen>>, - map_tuple map>, frozen>>, - map_udt map, frozen<%s>>, - )""" % (name, name, name)) - - validate = partial(self.insert_select_column, s, name) - validate('map_map', OrderedMap([({1: 1, 2: 2}, {3: 3, 4: 4}), ({5: 5, 6: 6}, {7: 7, 8: 8})])) - validate('map_set', OrderedMap([(set((1, 2)), set((3, 4))), (set((5, 6)), set((7, 8)))])) - validate('map_list', OrderedMap([([1, 2], [3, 4]), ([5, 6], [7, 8])])) - validate('map_tuple', OrderedMap([((1, 2), (3,)), ((4, 5), (6,))])) - - value = nested_collection_udt({1: 'v1', 2: 'v2'}, (3, 'v3'), [4, 5, 6, 7], set((8, 9, 10))) - key = nested_collection_udt_nested(value.m, value.t, value.l, value.s, value) - key2 = nested_collection_udt_nested({3: 'v3'}, value.t, value.l, value.s, value) - validate('map_udt', OrderedMap([(key, value), (key2, value)])) - - s.execute("DROP TABLE %s" % (name)) - s.execute("DROP TYPE %s_nested" % (name)) - s.execute("DROP TYPE %s" % (name)) - s.shutdown() - - def test_reading_composite_type(self): - s = self._session s.execute(""" CREATE TABLE composites ( a int PRIMARY KEY, @@ -761,3 +726,5 @@ class TypeTests(unittest.TestCase): result = s.execute("SELECT * FROM composites WHERE a = 0")[0] self.assertEqual(0, result.a) self.assertEqual(('abc',), result.b) + + c.shutdown()