tests for non-primitive datatypes for UDTs

This commit is contained in:
Kishan Karunaratne
2014-07-21 16:50:21 -07:00
parent 938237f6f6
commit 1692c375f7
2 changed files with 87 additions and 5 deletions

View File

@@ -16,6 +16,11 @@ import datetime
from uuid import UUID
import pytz
try:
from blist import sortedset
except ImportError:
sortedset = set # noqa
DATA_TYPE_PRIMITIVES = [
'ascii',
'bigint',
@@ -113,3 +118,19 @@ def get_sample(datatype):
"""
return SAMPLE_DATA[datatype]
def get_nonprim_sample(non_prim_type):
"""
Helper method to access created sample data for non-primitives
"""
if non_prim_type == 'list':
return ['text', 'text']
elif non_prim_type == 'set':
return sortedset(['text'])
elif non_prim_type == 'map':
return {'text': 'text'}
elif non_prim_type == 'tuple':
return ('text', 'text')
else:
raise Exception('Missing handling of non-primitive type {0}.'.format(non_prim_type))

View File

@@ -17,6 +17,11 @@ try:
except ImportError:
import unittest # noqa
# try:
# from blist import sortedset
# except ImportError:
# sortedset = set # noqa
import logging
log = logging.getLogger(__name__)
@@ -25,7 +30,8 @@ from collections import namedtuple
from cassandra.cluster import Cluster
from tests.integration import get_server_versions, PROTOCOL_VERSION
from tests.integration.long.datatype_utils import get_sample, DATA_TYPE_PRIMITIVES
from tests.integration.datatype_utils import get_sample, get_nonprim_sample,\
DATA_TYPE_PRIMITIVES, DATA_TYPE_NON_PRIMITIVE_NAMES
class TypeTests(unittest.TestCase):
@@ -235,7 +241,7 @@ class TypeTests(unittest.TestCase):
self.assertRaises(UserTypeDoesNotExist, c.register_user_type, "some_bad_keyspace", "user", User)
self.assertRaises(UserTypeDoesNotExist, c.register_user_type, "system", "user", User)
def test_datatypes(self):
def test_primitive_datatypes(self):
"""
Test for inserting various types of DATA_TYPE_PRIMITIVES into UDT's
"""
@@ -244,10 +250,10 @@ class TypeTests(unittest.TestCase):
# create keyspace
s.execute("""
CREATE KEYSPACE test_datatypes
CREATE KEYSPACE test_primitive_datatypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("test_datatypes")
s.set_keyspace("test_primitive_datatypes")
# create UDT
alpha_type_list = []
@@ -267,7 +273,7 @@ class TypeTests(unittest.TestCase):
for i in range(ord('a'), ord('a')+len(DATA_TYPE_PRIMITIVES)):
alphabet_list.append('{}'.format(chr(i)))
Alldatatypes = namedtuple("alldatatypes", alphabet_list)
c.register_user_type("test_datatypes", "alldatatypes", Alldatatypes)
c.register_user_type("test_primitive_datatypes", "alldatatypes", Alldatatypes)
# insert UDT data
params = []
@@ -286,3 +292,58 @@ class TypeTests(unittest.TestCase):
self.assertEqual(expected, actual)
c.shutdown()
def test_nonprimitive_datatypes(self):
"""
Test for inserting various types of DATA_TYPE_NON_PRIMITIVE into UDT's
"""
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
# create keyspace
s.execute("""
CREATE KEYSPACE test_nonprimitive_datatypes
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("test_nonprimitive_datatypes")
# create UDT
alpha_type_list = []
start_index = ord('a')
for i, datatype in enumerate(DATA_TYPE_NON_PRIMITIVE_NAMES):
if datatype == "map":
alpha_type_list.append("{0} {1}<text, text>".format(chr(start_index + i), datatype))
else:
alpha_type_list.append("{0} {1}<text>".format(chr(start_index + i), datatype))
s.execute("""
CREATE TYPE alldatatypes ({0})
""".format(', '.join(alpha_type_list))
)
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b alldatatypes)")
# register UDT
alphabet_list = []
for i in range(ord('a'), ord('a')+len(DATA_TYPE_NON_PRIMITIVE_NAMES)):
alphabet_list.append('{}'.format(chr(i)))
Alldatatypes = namedtuple("alldatatypes", alphabet_list)
c.register_user_type("test_nonprimitive_datatypes", "alldatatypes", Alldatatypes)
# insert UDT data
params = []
for datatype in DATA_TYPE_NON_PRIMITIVE_NAMES:
params.append((get_nonprim_sample(datatype)))
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
s.execute(insert, (0, Alldatatypes(*params)))
# retrieve and verify data
results = s.execute("SELECT * FROM mytable")
self.assertEqual(1, len(results))
row = results[0].b
for expected, actual in zip(params, row):
self.assertEqual(expected, actual)
c.shutdown()