271 lines
6.9 KiB
Python
271 lines
6.9 KiB
Python
# Copyright 2013-2016 DataStax, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
try:
|
|
import unittest2 as unittest
|
|
except ImportError:
|
|
import unittest # noqa
|
|
|
|
from datetime import datetime, timedelta, time
|
|
from decimal import Decimal
|
|
from uuid import uuid1, uuid4, UUID
|
|
import six
|
|
|
|
from cassandra.cqlengine import columns
|
|
from cassandra.cqlengine.management import sync_table
|
|
from cassandra.cqlengine.management import drop_table
|
|
from cassandra.cqlengine.models import Model
|
|
|
|
from cassandra.util import Date, Time
|
|
|
|
from tests.integration import PROTOCOL_VERSION
|
|
from tests.integration.cqlengine.base import BaseCassEngTestCase
|
|
|
|
|
|
class BaseColumnIOTest(BaseCassEngTestCase):
|
|
"""
|
|
Tests that values are come out of cassandra in the format we expect
|
|
|
|
To test a column type, subclass this test, define the column, and the primary key
|
|
and data values you want to test
|
|
"""
|
|
|
|
# The generated test model is assigned here
|
|
_generated_model = None
|
|
|
|
# the column we want to test
|
|
column = None
|
|
|
|
# the values we want to test against, you can
|
|
# use a single value, or multiple comma separated values
|
|
pkey_val = None
|
|
data_val = None
|
|
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super(BaseColumnIOTest, cls).setUpClass()
|
|
|
|
# if the test column hasn't been defined, bail out
|
|
if not cls.column:
|
|
return
|
|
|
|
# create a table with the given column
|
|
class IOTestModel(Model):
|
|
pkey = cls.column(primary_key=True)
|
|
data = cls.column()
|
|
|
|
cls._generated_model = IOTestModel
|
|
sync_table(cls._generated_model)
|
|
|
|
# tupleify the tested values
|
|
if not isinstance(cls.pkey_val, tuple):
|
|
cls.pkey_val = cls.pkey_val,
|
|
if not isinstance(cls.data_val, tuple):
|
|
cls.data_val = cls.data_val,
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
super(BaseColumnIOTest, cls).tearDownClass()
|
|
if not cls.column:
|
|
return
|
|
drop_table(cls._generated_model)
|
|
|
|
def comparator_converter(self, val):
|
|
""" If you want to convert the original value used to compare the model vales """
|
|
return val
|
|
|
|
def test_column_io(self):
|
|
""" Tests the given models class creates and retrieves values as expected """
|
|
if not self.column:
|
|
return
|
|
for pkey, data in zip(self.pkey_val, self.data_val):
|
|
# create
|
|
m1 = self._generated_model.create(pkey=pkey, data=data)
|
|
|
|
# get
|
|
m2 = self._generated_model.get(pkey=pkey)
|
|
assert m1.pkey == m2.pkey == self.comparator_converter(pkey), self.column
|
|
assert m1.data == m2.data == self.comparator_converter(data), self.column
|
|
|
|
# delete
|
|
self._generated_model.filter(pkey=pkey).delete()
|
|
|
|
|
|
class TestBlobIO(BaseColumnIOTest):
|
|
|
|
column = columns.Blob
|
|
pkey_val = six.b('blake'), uuid4().bytes
|
|
data_val = six.b('eggleston'), uuid4().bytes
|
|
|
|
|
|
class TestBlobIO2(BaseColumnIOTest):
|
|
|
|
column = columns.Blob
|
|
pkey_val = bytearray(six.b('blake')), uuid4().bytes
|
|
data_val = bytearray(six.b('eggleston')), uuid4().bytes
|
|
|
|
|
|
class TestTextIO(BaseColumnIOTest):
|
|
|
|
column = columns.Text
|
|
pkey_val = 'bacon'
|
|
data_val = 'monkey'
|
|
|
|
|
|
class TestNonBinaryTextIO(BaseColumnIOTest):
|
|
|
|
column = columns.Text
|
|
pkey_val = 'bacon'
|
|
data_val = '0xmonkey'
|
|
|
|
|
|
class TestInteger(BaseColumnIOTest):
|
|
|
|
column = columns.Integer
|
|
pkey_val = 5
|
|
data_val = 6
|
|
|
|
|
|
class TestBigInt(BaseColumnIOTest):
|
|
|
|
column = columns.BigInt
|
|
pkey_val = 6
|
|
data_val = pow(2, 63) - 1
|
|
|
|
|
|
class TestDateTime(BaseColumnIOTest):
|
|
|
|
column = columns.DateTime
|
|
|
|
now = datetime(*datetime.now().timetuple()[:6])
|
|
pkey_val = now
|
|
data_val = now + timedelta(days=1)
|
|
|
|
|
|
class TestUUID(BaseColumnIOTest):
|
|
|
|
column = columns.UUID
|
|
|
|
pkey_val = str(uuid4()), uuid4()
|
|
data_val = str(uuid4()), uuid4()
|
|
|
|
def comparator_converter(self, val):
|
|
return val if isinstance(val, UUID) else UUID(val)
|
|
|
|
|
|
class TestTimeUUID(BaseColumnIOTest):
|
|
|
|
column = columns.TimeUUID
|
|
|
|
pkey_val = str(uuid1()), uuid1()
|
|
data_val = str(uuid1()), uuid1()
|
|
|
|
def comparator_converter(self, val):
|
|
return val if isinstance(val, UUID) else UUID(val)
|
|
|
|
|
|
class TestFloatIO(BaseColumnIOTest):
|
|
|
|
column = columns.Float
|
|
|
|
pkey_val = 4.75
|
|
data_val = -1.5
|
|
|
|
|
|
class TestDoubleIO(BaseColumnIOTest):
|
|
|
|
column = columns.Double
|
|
|
|
pkey_val = 3.14
|
|
data_val = -1982.11
|
|
|
|
|
|
class TestDecimalIO(BaseColumnIOTest):
|
|
|
|
column = columns.Decimal
|
|
|
|
pkey_val = Decimal('1.35'), 5, '2.4'
|
|
data_val = Decimal('0.005'), 3.5, '8'
|
|
|
|
def comparator_converter(self, val):
|
|
return Decimal(repr(val) if isinstance(val, float) else val)
|
|
|
|
|
|
class ProtocolV4Test(BaseColumnIOTest):
|
|
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
if PROTOCOL_VERSION >= 4:
|
|
super(ProtocolV4Test, cls).setUpClass()
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
if PROTOCOL_VERSION >= 4:
|
|
super(ProtocolV4Test, cls).tearDownClass()
|
|
|
|
class TestDate(ProtocolV4Test):
|
|
|
|
def setUp(self):
|
|
if PROTOCOL_VERSION < 4:
|
|
raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION))
|
|
|
|
super(TestDate, self).setUp()
|
|
|
|
column = columns.Date
|
|
|
|
now = Date(datetime.now().date())
|
|
pkey_val = now
|
|
data_val = Date(now.days_from_epoch + 1)
|
|
|
|
|
|
class TestTime(ProtocolV4Test):
|
|
|
|
def setUp(self):
|
|
if PROTOCOL_VERSION < 4:
|
|
raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION))
|
|
|
|
super(TestTime, self).setUp()
|
|
|
|
column = columns.Time
|
|
|
|
pkey_val = Time(time(2, 12, 7, 48))
|
|
data_val = Time(time(16, 47, 25, 7))
|
|
|
|
|
|
class TestSmallInt(ProtocolV4Test):
|
|
|
|
def setUp(self):
|
|
if PROTOCOL_VERSION < 4:
|
|
raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION))
|
|
|
|
super(TestSmallInt, self).setUp()
|
|
|
|
column = columns.SmallInt
|
|
|
|
pkey_val = 16768
|
|
data_val = 32523
|
|
|
|
|
|
class TestTinyInt(ProtocolV4Test):
|
|
|
|
def setUp(self):
|
|
if PROTOCOL_VERSION < 4:
|
|
raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION))
|
|
|
|
super(TestTinyInt, self).setUp()
|
|
|
|
column = columns.TinyInt
|
|
|
|
pkey_val = 1
|
|
data_val = 123
|