Merge branch 'tuple-test-merge'

This commit is contained in:
Tyler Hobbs
2014-07-17 21:46:48 -05:00
2 changed files with 174 additions and 6 deletions

View File

@@ -0,0 +1,106 @@
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from decimal import Decimal
import datetime
from uuid import UUID
import pytz
DATA_TYPE_PRIMITIVES = [
'ascii',
'bigint',
'blob',
'boolean',
# 'counter', counters are not allowed inside tuples
'decimal',
'double',
'float',
'inet',
'int',
'text',
'timestamp',
'timeuuid',
'uuid',
'varchar',
'varint',
]
DATA_TYPE_NON_PRIMITIVE_NAMES = [
'list',
'set',
'map',
]
def get_sample_data():
sample_data = {}
for datatype in DATA_TYPE_PRIMITIVES:
if datatype == 'ascii':
sample_data[datatype] = 'ascii'
elif datatype == 'bigint':
sample_data[datatype] = 2 ** 63 - 1
elif datatype == 'blob':
sample_data[datatype] = bytearray(b'hello world')
elif datatype == 'boolean':
sample_data[datatype] = True
elif datatype == 'counter':
# Not supported in an insert statement
pass
elif datatype == 'decimal':
sample_data[datatype] = Decimal('12.3E+7')
elif datatype == 'double':
sample_data[datatype] = 1.23E+8
elif datatype == 'float':
sample_data[datatype] = 3.4028234663852886e+38
elif datatype == 'inet':
sample_data[datatype] = '123.123.123.123'
elif datatype == 'int':
sample_data[datatype] = 2147483647
elif datatype == 'text':
sample_data[datatype] = 'text'
elif datatype == 'timestamp':
sample_data[datatype] = datetime.datetime.fromtimestamp(872835240, tz=pytz.timezone('America/New_York')).astimezone(pytz.UTC).replace(tzinfo=None)
elif datatype == 'timeuuid':
sample_data[datatype] = UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66')
elif datatype == 'uuid':
sample_data[datatype] = UUID('067e6162-3b6f-4ae2-a171-2470b63dff00')
elif datatype == 'varchar':
sample_data[datatype] = 'varchar'
elif datatype == 'varint':
sample_data[datatype] = int(str(2147483647) + '000')
else:
raise Exception('Missing handling of %s.' % datatype)
return sample_data
SAMPLE_DATA = get_sample_data()
def get_sample(datatype):
return SAMPLE_DATA[datatype]

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.integration.datatype_utils import get_sample, DATA_TYPE_PRIMITIVES
try:
import unittest2 as unittest
@@ -87,8 +88,8 @@ class TypeTests(unittest.TestCase):
s.execute(query, params)
expected_vals = [
'key1',
bytearray(b'blobyblob')
'key1',
bytearray(b'blobyblob')
]
results = s.execute("SELECT * FROM mytable")
@@ -428,14 +429,75 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT b FROM mytable WHERE a=1")[0]
self.assertEqual(partial_result, result.b)
subpartial = ('zoo',)
subpartial_result = subpartial + (None, None)
s.execute("INSERT INTO mytable (a, b) VALUES (2, %s)", parameters=(subpartial,))
result = s.execute("SELECT b FROM mytable WHERE a=2")[0]
self.assertEqual(subpartial_result, result.b)
# test prepared statement
prepared = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
s.execute(prepared, parameters=(2, complete))
s.execute(prepared, parameters=(3, partial))
s.execute(prepared, parameters=(3, complete))
s.execute(prepared, parameters=(4, partial))
s.execute(prepared, parameters=(5, subpartial))
prepared = s.prepare("SELECT b FROM mytable WHERE a=?")
self.assertEqual(complete, s.execute(prepared, (2,))[0].b)
self.assertEqual(partial_result, s.execute(prepared, (3,))[0].b)
self.assertEqual(complete, s.execute(prepared, (3,))[0].b)
self.assertEqual(partial_result, s.execute(prepared, (4,))[0].b)
self.assertEqual(subpartial_result, s.execute(prepared, (5,))[0].b)
def test_tuple_type_varying_lengths(self):
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.row_factory = dict_factory
s.encoders[tuple] = cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_type_varying_lengths
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_tuple_type_varying_lengths")
lengths = (1, 2, 3, 384)
value_schema = []
for i in lengths:
value_schema += [' v_%s tuple<%s>' % (i, ', '.join(['int'] * i))]
s.execute("CREATE TABLE mytable (k int PRIMARY KEY, %s)" % (', '.join(value_schema),))
for i in lengths:
created_tuple = tuple(range(0, i))
s.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", (i, created_tuple))
result = s.execute("SELECT v_%s FROM mytable WHERE k=0", (i,))[0]
self.assertEqual(tuple(created_tuple), result['v_%s' % i])
def test_tuple_subtypes(self):
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.encoders[tuple] = cql_encode_tuple
s.execute("""CREATE KEYSPACE test_tuple_types
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_tuple_types")
s.execute("CREATE TABLE mytable ("
"k int PRIMARY KEY, "
"v tuple<%s>)" % ','.join(DATA_TYPE_PRIMITIVES))
for i in range(len(DATA_TYPE_PRIMITIVES)):
created_tuple = [get_sample(DATA_TYPE_PRIMITIVES[j]) for j in range(i + 1)]
response_tuple = tuple(created_tuple + [None for j in range(len(DATA_TYPE_PRIMITIVES) - i - 1)])
written_tuple = tuple(created_tuple)
s.execute("INSERT INTO mytable (k, v) VALUES (%s, %s)", (i, written_tuple))
result = s.execute("SELECT v FROM mytable WHERE k=%s", (i,))[0]
self.assertEqual(response_tuple, result.v)
def test_unicode_query_string(self):
c = Cluster(protocol_version=PROTOCOL_VERSION)