Refactored/Added some tests to CQLEngine

This commit is contained in:
bjmb 2017-04-11 16:28:13 -04:00
parent 941dab462b
commit 7bc0bfd8cb
9 changed files with 459 additions and 189 deletions

View File

@ -337,7 +337,6 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET, connect
query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
elif isinstance(query, six.string_types):
query = SimpleStatement(query, consistency_level=consistency_level)
log.debug(format_log_context(query.query_string, connection=connection))
result = conn.session.execute(query, params, timeout=timeout)

View File

@ -21,7 +21,7 @@ except ImportError:
from cassandra import ConsistencyLevel
from cassandra.cqlengine import connection
from cassandra.cqlengine.management import create_keyspace_simple, CQLENG_ALLOW_SCHEMA_MANAGEMENT
from cassandra.cqlengine.management import create_keyspace_simple, drop_keyspace, CQLENG_ALLOW_SCHEMA_MANAGEMENT
import cassandra
from tests.integration import get_server_versions, use_single_node, PROTOCOL_VERSION, CASSANDRA_IP, set_default_cass_ip
@ -45,7 +45,6 @@ def setup_package():
def teardown_package():
connection.unregister_connection("default")
def is_prepend_reversed():
# do we have https://issues.apache.org/jira/browse/CASSANDRA-8733 ?
ver, _ = get_server_versions()

View File

@ -19,7 +19,20 @@ except ImportError:
import sys
from cassandra.cqlengine.connection import get_session
from cassandra.cqlengine.models import Model
from cassandra.cqlengine import columns
from uuid import uuid4
class TestQueryUpdateModel(Model):
partition = columns.UUID(primary_key=True, default=uuid4)
cluster = columns.Integer(primary_key=True)
count = columns.Integer(required=False)
text = columns.Text(required=False, index=True)
text_set = columns.Set(columns.Text, required=False)
text_list = columns.List(columns.Text, required=False)
text_map = columns.Map(columns.Text, columns.Text, required=False)
class BaseCassEngTestCase(unittest.TestCase):

View File

@ -18,30 +18,21 @@ except ImportError:
import unittest # noqa
import sys
from datetime import datetime, timedelta, date, tzinfo
from datetime import datetime, timedelta, date, tzinfo, time
from decimal import Decimal as D
from uuid import uuid4, uuid1
from cassandra import InvalidRequest
from cassandra.cqlengine.columns import TimeUUID
from cassandra.cqlengine.columns import Ascii
from cassandra.cqlengine.columns import Text
from cassandra.cqlengine.columns import Integer
from cassandra.cqlengine.columns import BigInt
from cassandra.cqlengine.columns import VarInt
from cassandra.cqlengine.columns import DateTime
from cassandra.cqlengine.columns import Date
from cassandra.cqlengine.columns import UUID
from cassandra.cqlengine.columns import Boolean
from cassandra.cqlengine.columns import Decimal
from cassandra.cqlengine.columns import Inet
from cassandra.cqlengine.columns import TimeUUID, Ascii, Text, Integer, BigInt, VarInt, DateTime, Date, UUID, Boolean, \
Decimal, Inet, Time, UserDefinedType, Map, List, Set, Tuple, Double, Float
from cassandra.cqlengine.connection import execute
from cassandra.cqlengine.management import sync_table, drop_table
from cassandra.cqlengine.models import Model, ValidationError
from cassandra.cqlengine.usertype import UserType
from cassandra import util
from tests.integration import PROTOCOL_VERSION
from tests.integration.cqlengine.base import BaseCassEngTestCase
from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel
class TestDatetime(BaseCassEngTestCase):
@ -62,7 +53,7 @@ class TestDatetime(BaseCassEngTestCase):
now = datetime.now()
self.DatetimeTest.objects.create(test_id=0, created_at=now)
dt2 = self.DatetimeTest.objects(test_id=0).first()
assert dt2.created_at.timetuple()[:6] == now.timetuple()[:6]
self.assertEqual(dt2.created_at.timetuple()[:6], now.timetuple()[:6])
def test_datetime_tzinfo_io(self):
class TZ(tzinfo):
@ -74,21 +65,27 @@ class TestDatetime(BaseCassEngTestCase):
now = datetime(1982, 1, 1, tzinfo=TZ())
dt = self.DatetimeTest.objects.create(test_id=1, created_at=now)
dt2 = self.DatetimeTest.objects(test_id=1).first()
assert dt2.created_at.timetuple()[:6] == (now + timedelta(hours=1)).timetuple()[:6]
self.assertEqual(dt2.created_at.timetuple()[:6], (now + timedelta(hours=1)).timetuple()[:6])
def test_datetime_date_support(self):
today = date.today()
self.DatetimeTest.objects.create(test_id=2, created_at=today)
dt2 = self.DatetimeTest.objects(test_id=2).first()
assert dt2.created_at.isoformat() == datetime(today.year, today.month, today.day).isoformat()
self.assertEqual(dt2.created_at.isoformat(), datetime(today.year, today.month, today.day).isoformat())
result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2).first()
self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time()))
result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2, created_at=today).first()
self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time()))
def test_datetime_none(self):
dt = self.DatetimeTest.objects.create(test_id=3, created_at=None)
dt2 = self.DatetimeTest.objects(test_id=3).first()
assert dt2.created_at is None
self.assertIsNone(dt2.created_at)
dts = self.DatetimeTest.objects.filter(test_id=3).values_list('created_at')
assert dts[0][0] is None
self.assertIsNone(dts[0][0])
def test_datetime_invalid(self):
dt_value= 'INVALID'
@ -99,13 +96,13 @@ class TestDatetime(BaseCassEngTestCase):
dt_value = 1454520554
self.DatetimeTest.objects.create(test_id=5, created_at=dt_value)
dt2 = self.DatetimeTest.objects(test_id=5).first()
assert dt2.created_at == datetime.utcfromtimestamp(dt_value)
self.assertEqual(dt2.created_at, datetime.utcfromtimestamp(dt_value))
def test_datetime_large(self):
dt_value = datetime(2038, 12, 31, 10, 10, 10, 123000)
self.DatetimeTest.objects.create(test_id=6, created_at=dt_value)
dt2 = self.DatetimeTest.objects(test_id=6).first()
assert dt2.created_at == dt_value
self.assertEqual(dt2.created_at, dt_value)
def test_datetime_truncate_microseconds(self):
"""
@ -187,49 +184,170 @@ class TestVarInt(BaseCassEngTestCase):
int2 = self.VarIntTest.objects(test_id=0).first()
self.assertEqual(int1.bignum, int2.bignum)
with self.assertRaises(ValidationError):
self.VarIntTest.objects.create(test_id=0, bignum="not_a_number")
class TestDate(BaseCassEngTestCase):
class DateTest(Model):
test_id = Integer(primary_key=True)
created_at = Date()
class DataType():
@classmethod
def setUpClass(cls):
if PROTOCOL_VERSION < 4:
return
sync_table(cls.DateTest)
class DataTypeTest(Model):
test_id = Integer(primary_key=True)
class_param = cls.db_klass()
cls.model_class = DataTypeTest
sync_table(cls.model_class)
@classmethod
def tearDownClass(cls):
if PROTOCOL_VERSION < 4:
return
drop_table(cls.DateTest)
drop_table(cls.model_class)
def setUp(self):
if PROTOCOL_VERSION < 4:
raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION))
def test_date_io(self):
today = date.today()
self.DateTest.objects.create(test_id=0, created_at=today)
result = self.DateTest.objects(test_id=0).first()
self.assertEqual(result.created_at, util.Date(today))
def _check_value_is_correct_in_db(self, value):
if value is None:
result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first()
self.assertIsNone(result.class_param)
else:
if not isinstance(value, self.python_klass):
value_to_compare = self.python_klass(value)
else:
value_to_compare = value
result = self.model_class.objects(test_id=0).first()
self.assertIsInstance(result.class_param, self.python_klass)
self.assertEqual(result.class_param, value_to_compare)
result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first()
self.assertIsInstance(result.class_param, self.python_klass)
self.assertEqual(result.class_param, value_to_compare)
result = self.model_class.objects.all().allow_filtering().filter(test_id=0, class_param=value).first()
self.assertIsInstance(result.class_param, self.python_klass)
self.assertEqual(result.class_param, value_to_compare)
return result
def test_date_io_using_datetime(self):
now = datetime.utcnow()
self.DateTest.objects.create(test_id=0, created_at=now)
result = self.DateTest.objects(test_id=0).first()
self.assertIsInstance(result.created_at, util.Date)
self.assertEqual(result.created_at, util.Date(now))
first_value = self.first_value
second_value = self.second_value
third_value = self.third_value
self.model_class.objects.create(test_id=0, class_param=first_value)
result = self._check_value_is_correct_in_db(first_value)
result.delete()
self.model_class.objects.create(test_id=0, class_param=second_value)
result = self._check_value_is_correct_in_db(second_value)
result.update(class_param=third_value).save()
result = self._check_value_is_correct_in_db(third_value)
result.update(class_param=None).save()
self._check_value_is_correct_in_db(None)
def test_date_none(self):
self.DateTest.objects.create(test_id=1, created_at=None)
dt2 = self.DateTest.objects(test_id=1).first()
assert dt2.created_at is None
self.model_class.objects.create(test_id=1, class_param=None)
dt2 = self.model_class.objects(test_id=1).first()
self.assertIsNone(dt2.class_param)
dts = self.DateTest.objects(test_id=1).values_list('created_at')
assert dts[0][0] is None
dts = self.model_class.objects(test_id=1).values_list('class_param')
self.assertIsNone(dts[0][0])
class TestDate(DataType, BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
cls.db_klass, cls.python_klass = Date, util.Date
cls.first_value, cls.second_value, cls.third_value = \
datetime.utcnow(), util.Date(datetime(1, 1, 1)), datetime(1, 1, 2)
super(TestDate, cls).setUpClass()
class TestTime(DataType, BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
cls.db_klass, cls.python_klass = Time, util.Time
cls.first_value, cls.second_value, cls.third_value = \
time(2, 12, 7, 48), util.Time(time(2, 12, 7, 49)), time(2, 12, 7, 50)
super(TestTime, cls).setUpClass()
class TestDateTime(DataType, BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
cls.db_klass, cls.python_klass = DateTime, datetime
cls.first_value, cls.second_value, cls.third_value = \
datetime(2017, 4, 13, 18, 34, 24, 317000), datetime(1, 1, 1), datetime(1, 1, 2)
super(TestDateTime, cls).setUpClass()
class TestBoolean(DataType, BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
cls.db_klass, cls.python_klass = Boolean, bool
cls.first_value, cls.second_value, cls.third_value = True, False, True
super(TestBoolean, cls).setUpClass()
class User(UserType):
# We use Date and Time to ensure to_python
# is called for these columns
age = Integer()
date_param = Date()
map_param = Map(Integer, Time)
list_param = List(Date)
set_param = Set(Date)
tuple_param = Tuple(Date, Decimal, Boolean, VarInt, Double, UUID)
class UserModel(Model):
test_id = Integer(primary_key=True)
class_param = UserDefinedType(User)
class TestUDT(DataType, BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
cls.db_klass, cls.python_klass = UserDefinedType, User
cls.first_value = User(
age=1,
date_param=datetime.utcnow(),
map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))},
list_param=[datetime(1, 1, 2), datetime(1, 1, 3)],
set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 1)))),
tuple_param=(datetime(1, 1, 3), 2, False, 1, 2.324, uuid4())
)
cls.second_value = User(
age=1,
date_param=datetime.utcnow(),
map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))},
list_param=[datetime(1, 1, 2), datetime(1, 2, 3)],
set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 1)))),
tuple_param=(datetime(1, 1, 2), 2, False, 1, 2.324, uuid4())
)
cls.third_value = User(
age=2,
date_param=datetime.utcnow(),
map_param={1: time(2, 12, 7, 51), 2: util.Time(time(2, 12, 7, 49))},
list_param=[datetime(1, 1, 2), datetime(1, 1, 4)],
set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 2)))),
tuple_param=(datetime(1, 1, 2), 3, False, 1, 2.3214, uuid4())
)
cls.model_class = UserModel
sync_table(cls.model_class)
class TestDecimal(BaseCassEngTestCase):
@ -319,7 +437,7 @@ class TestInteger(BaseCassEngTestCase):
class IntegerTest(Model):
test_id = UUID(primary_key=True, default=lambda:uuid4())
value = Integer(default=0, required=True)
value = Integer(default=0, required=True)
def test_default_zero_fields_validate(self):
""" Tests that integer columns with a default value of 0 validate """
@ -644,4 +762,3 @@ class TestInet(BaseCassEngTestCase):
# TODO: presently this only tests that the server blows it up. Is there supposed to be local validation?
with self.assertRaises(InvalidRequest):
self.InetTestModel.create(address="what is going on here?")

View File

@ -24,11 +24,45 @@ from uuid import UUID, uuid4
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.usertype import UserType, UserTypeDefinitionException
from cassandra.cqlengine import columns, connection
from cassandra.cqlengine.management import sync_table, sync_type, create_keyspace_simple, drop_keyspace, drop_table
from cassandra.cqlengine.management import sync_table, drop_table, sync_type, create_keyspace_simple, drop_keyspace
from cassandra.cqlengine import ValidationError
from cassandra.util import Date, Time
from tests.integration import PROTOCOL_VERSION
from tests.integration.cqlengine.base import BaseCassEngTestCase
from tests.integration.cqlengine import DEFAULT_KEYSPACE
class User(UserType):
age = columns.Integer()
name = columns.Text()
class UserModel(Model):
id = columns.Integer(primary_key=True)
info = columns.UserDefinedType(User)
class AllDatatypes(UserType):
a = columns.Ascii()
b = columns.BigInt()
c = columns.Blob()
d = columns.Boolean()
e = columns.DateTime()
f = columns.Decimal()
g = columns.Double()
h = columns.Float()
i = columns.Inet()
j = columns.Integer()
k = columns.Text()
l = columns.TimeUUID()
m = columns.UUID()
n = columns.VarInt()
class AllDatatypesModel(Model):
id = columns.Integer(primary_key=True)
data = columns.UserDefinedType(AllDatatypes)
class UserDefinedTypeTests(BaseCassEngTestCase):
@ -42,7 +76,7 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
age = columns.Integer()
name = columns.Text()
sync_type("cqlengine_test", User)
sync_type(DEFAULT_KEYSPACE, User)
user = User(age=42, name="John")
self.assertEqual(42, user.age)
self.assertEqual("John", user.name)
@ -53,8 +87,10 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
name = columns.Text()
gender = columns.Text()
sync_type("cqlengine_test", User)
user = User(age=42, name="John", gender="male")
sync_type(DEFAULT_KEYSPACE, User)
user = User(age=42)
user["name"] = "John"
user["gender"] = "male"
self.assertEqual(42, user.age)
self.assertEqual("John", user.name)
self.assertEqual("male", user.gender)
@ -64,19 +100,12 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
age = columns.Integer()
name = columns.Text()
sync_type("cqlengine_test", User)
sync_type(DEFAULT_KEYSPACE, User)
user = User(age=42, name="John", gender="male")
with self.assertRaises(AttributeError):
user.gender
def test_can_insert_udts(self):
class User(UserType):
age = columns.Integer()
name = columns.Text()
class UserModel(Model):
id = columns.Integer(primary_key=True)
info = columns.UserDefinedType(User)
sync_table(UserModel)
@ -90,16 +119,9 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
self.assertTrue(type(john.info) is User)
self.assertEqual(42, john.info.age)
self.assertEqual("John", john.info.name)
drop_table(UserModel)
def test_can_update_udts(self):
class User(UserType):
age = columns.Integer()
name = columns.Text()
class UserModel(Model):
id = columns.Integer(primary_key=True)
info = columns.UserDefinedType(User)
sync_table(UserModel)
user = User(age=42, name="John")
@ -113,18 +135,11 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
created_user.update()
mary_info = UserModel.objects().first().info
self.assertEqual(22, mary_info.age)
self.assertEqual("Mary", mary_info.name)
self.assertEqual(22, mary_info["age"])
self.assertEqual("Mary", mary_info["name"])
drop_table(UserModel)
def test_can_update_udts_with_nones(self):
class User(UserType):
age = columns.Integer()
name = columns.Text()
class UserModel(Model):
id = columns.Integer(primary_key=True)
info = columns.UserDefinedType(User)
sync_table(UserModel)
user = User(age=42, name="John")
@ -139,45 +154,43 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
john_info = UserModel.objects().first().info
self.assertIsNone(john_info)
drop_table(UserModel)
def test_can_create_same_udt_different_keyspaces(self):
class User(UserType):
age = columns.Integer()
name = columns.Text()
sync_type("cqlengine_test", User)
sync_type(DEFAULT_KEYSPACE, User)
create_keyspace_simple("simplex", 1)
sync_type("simplex", User)
drop_keyspace("simplex")
def test_can_insert_partial_udts(self):
class User(UserType):
class UserGender(UserType):
age = columns.Integer()
name = columns.Text()
gender = columns.Text()
class UserModel(Model):
class UserModelGender(Model):
id = columns.Integer(primary_key=True)
info = columns.UserDefinedType(User)
info = columns.UserDefinedType(UserGender)
sync_table(UserModel)
sync_table(UserModelGender)
user = User(age=42, name="John")
UserModel.create(id=0, info=user)
user = UserGender(age=42, name="John")
UserModelGender.create(id=0, info=user)
john_info = UserModel.objects().first().info
john_info = UserModelGender.objects().first().info
self.assertEqual(42, john_info.age)
self.assertEqual("John", john_info.name)
self.assertIsNone(john_info.gender)
user = User(age=42)
UserModel.create(id=0, info=user)
user = UserGender(age=42)
UserModelGender.create(id=0, info=user)
john_info = UserModel.objects().first().info
john_info = UserModelGender.objects().first().info
self.assertEqual(42, john_info.age)
self.assertIsNone(john_info.name)
self.assertIsNone(john_info.gender)
drop_table(UserModelGender)
def test_can_insert_nested_udts(self):
class Depth_0(UserType):
@ -215,6 +228,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
self.assertEqual(udts[2], output.v_2)
self.assertEqual(udts[3], output.v_3)
drop_table(DepthModel)
def test_can_insert_udts_with_nones(self):
"""
Test for inserting all column types as empty into a UserType as None's
@ -230,27 +245,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
@test_category data_types:udt
"""
class AllDatatypes(UserType):
a = columns.Ascii()
b = columns.BigInt()
c = columns.Blob()
d = columns.Boolean()
e = columns.DateTime()
f = columns.Decimal()
g = columns.Double()
h = columns.Float()
i = columns.Inet()
j = columns.Integer()
k = columns.Text()
l = columns.TimeUUID()
m = columns.UUID()
n = columns.VarInt()
class AllDatatypesModel(Model):
id = columns.Integer(primary_key=True)
data = columns.UserDefinedType(AllDatatypes)
sync_table(AllDatatypesModel)
input = AllDatatypes(a=None, b=None, c=None, d=None, e=None, f=None, g=None, h=None, i=None, j=None, k=None,
@ -262,6 +256,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
output = AllDatatypesModel.objects().first().data
self.assertEqual(input, output)
drop_table(AllDatatypesModel)
def test_can_insert_udts_with_all_datatypes(self):
"""
Test for inserting all column types into a UserType
@ -277,27 +273,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
@test_category data_types:udt
"""
class AllDatatypes(UserType):
a = columns.Ascii()
b = columns.BigInt()
c = columns.Blob()
d = columns.Boolean()
e = columns.DateTime()
f = columns.Decimal()
g = columns.Double()
h = columns.Float()
i = columns.Inet()
j = columns.Integer()
k = columns.Text()
l = columns.TimeUUID()
m = columns.UUID()
n = columns.VarInt()
class AllDatatypesModel(Model):
id = columns.Integer(primary_key=True)
data = columns.UserDefinedType(AllDatatypes)
sync_table(AllDatatypesModel)
input = AllDatatypes(a='ascii', b=2 ** 63 - 1, c=bytearray(b'hello world'), d=True,
@ -313,6 +288,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
for i in range(ord('a'), ord('a') + 14):
self.assertEqual(input[chr(i)], output[chr(i)])
drop_table(AllDatatypesModel)
def test_can_insert_udts_protocol_v4_datatypes(self):
"""
Test for inserting all protocol v4 column types into a UserType
@ -354,6 +331,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
for i in range(ord('a'), ord('a') + 3):
self.assertEqual(input[chr(i)], output[chr(i)])
drop_table(Allv4DatatypesModel)
def test_nested_udts_inserts(self):
"""
Test for inserting collections of user types using cql engine.
@ -394,6 +373,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
names_output = Container.objects().first().names
self.assertEqual(names_output, names)
drop_table(Container)
def test_udts_with_unicode(self):
"""
Test for inserting models with unicode and udt columns.
@ -410,10 +391,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
ascii_name = 'normal name'
unicode_name = u'Fran\u00E7ois'
class User(UserType):
age = columns.Integer()
name = columns.Text()
class UserModelText(Model):
id = columns.Text(primary_key=True)
info = columns.UserDefinedType(User)
@ -427,10 +404,9 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
UserModelText.create(id=unicode_name, info=user_template_ascii)
UserModelText.create(id=unicode_name, info=user_template_unicode)
drop_table(UserModelText)
def test_register_default_keyspace(self):
class User(UserType):
age = columns.Integer()
name = columns.Text()
from cassandra.cqlengine import models
from cassandra.cqlengine import connection
@ -495,6 +471,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
self.assertEqual(info.a, age)
self.assertEqual(info.n, name)
drop_table(TheModel)
def test_db_field_overload(self):
"""
Tests for db_field UserTypeDefinitionException
@ -520,9 +498,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
def test_set_udt_fields(self):
# PYTHON-502
class User(UserType):
age = columns.Integer()
name = columns.Text()
u = User()
u.age = 20
@ -562,3 +537,63 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
self.assertEqual(t.nested[0].default_text, "default text")
self.assertIsNotNone(t.simple.test_id)
self.assertEqual(t.simple.default_text, "default text")
drop_table(OuterModel)
def test_udt_validate(self):
"""
Test to verify restrictions are honored and that validate is called
for each member of the UDT when an updated is attempted
@since 3.10
@jira_ticket PYTHON-505
@expected_result a validation error is arisen due to the name being
too long
@test_category data_types:object_mapper
"""
class UserValidate(UserType):
age = columns.Integer()
name = columns.Text(max_length=2)
class UserModelValidate(Model):
id = columns.Integer(primary_key=True)
info = columns.UserDefinedType(UserValidate)
sync_table(UserModelValidate)
user = UserValidate(age=1, name="Robert")
item = UserModelValidate(id=1, info=user)
with self.assertRaises(ValidationError):
item.save()
drop_table(UserModelValidate)
def test_udt_validate_with_default(self):
"""
Test to verify restrictions are honored and that validate is called
on the default value
@since 3.10
@jira_ticket PYTHON-505
@expected_result a validation error is arisen due to the name being
too long
@test_category data_types:object_mapper
"""
class UserValidateDefault(UserType):
age = columns.Integer()
name = columns.Text(max_length=2, default="Robert")
class UserModelValidateDefault(Model):
id = columns.Integer(primary_key=True)
info = columns.UserDefinedType(UserValidateDefault)
sync_table(UserModelValidateDefault)
user = UserValidateDefault(age=1)
item = UserModelValidateDefault(id=1, info=user)
with self.assertRaises(ValidationError):
item.save()
drop_table(UserModelValidateDefault)

View File

@ -21,7 +21,7 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase
from cassandra.cqlengine.models import Model
from cassandra.cqlengine import columns
from cassandra.cqlengine.management import sync_table, drop_table
from cassandra.cqlengine.usertype import UserType
class TestUpdateModel(Model):
@ -147,19 +147,40 @@ class ModelUpdateTests(BaseCassEngTestCase):
m0.update(partition=uuid4())
class UDT(UserType):
age = columns.Integer()
mf = columns.Map(columns.Integer, columns.Integer)
dummy_udt = columns.Integer(default=42)
class ModelWithDefault(Model):
id = columns.Integer(primary_key=True)
mf = columns.Map(columns.Integer, columns.Integer)
dummy = columns.Integer(default=42)
id = columns.Integer(primary_key=True)
mf = columns.Map(columns.Integer, columns.Integer)
dummy = columns.Integer(default=42)
udt = columns.UserDefinedType(UDT)
udt_default = columns.UserDefinedType(UDT, default=UDT(age=1, mf={2:2}))
class UDTWithDefault(UserType):
age = columns.Integer()
mf = columns.Map(columns.Integer, columns.Integer, default={2:2})
dummy_udt = columns.Integer(default=42)
class ModelWithDefaultCollection(Model):
id = columns.Integer(primary_key=True)
mf = columns.Map(columns.Integer, columns.Integer, default={2:2})
dummy = columns.Integer(default=42)
id = columns.Integer(primary_key=True)
mf = columns.Map(columns.Integer, columns.Integer, default={2:2})
dummy = columns.Integer(default=42)
udt = columns.UserDefinedType(UDT)
udt_default = columns.UserDefinedType(UDT, default=UDT(age=1, mf={2: 2}))
class ModelWithDefaultTests(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
cls.default_udt = UDT(age=1, mf={2:2}, dummy_udt=42)
def setUp(self):
sync_table(ModelWithDefault)
@ -176,17 +197,19 @@ class ModelWithDefaultTests(BaseCassEngTestCase):
@test_category object_mapper
"""
initial = ModelWithDefault(id=1, mf={0: 0}, dummy=0)
first_udt = UDT(age=1, mf={2:2}, dummy_udt=0)
initial = ModelWithDefault(id=1, mf={0: 0}, dummy=0, udt=first_udt, udt_default=first_udt)
initial.save()
self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
{'id': 1, 'dummy': 0, 'mf': {0: 0}})
{'id': 1, 'dummy': 0, 'mf': {0: 0}, "udt": first_udt, "udt_default": first_udt})
second_udt = UDT(age=1, mf={3: 3}, dummy_udt=12)
second = ModelWithDefault(id=1)
second.update(mf={0: 1})
second.update(mf={0: 1}, udt=second_udt)
self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
{'id': 1, 'dummy': 0, 'mf': {0: 1}})
{'id': 1, 'dummy': 0, 'mf': {0: 1}, "udt": second_udt, "udt_default": first_udt})
def test_value_is_written_if_is_default(self):
"""
@ -202,10 +225,11 @@ class ModelWithDefaultTests(BaseCassEngTestCase):
initial = ModelWithDefault(id=1)
initial.mf = {0: 0}
initial.dummy = 42
initial.udt_default = self.default_udt
initial.update()
self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
{'id': 1, 'dummy': 42, 'mf': {0: 0}})
{'id': 1, 'dummy': 42, 'mf': {0: 0}, "udt": None, "udt_default": self.default_udt})
def test_null_update_is_respected(self):
"""
@ -223,10 +247,11 @@ class ModelWithDefaultTests(BaseCassEngTestCase):
q = ModelWithDefault.objects.all().allow_filtering()
obj = q.filter(id=1).get()
obj.update(dummy=None)
updated_udt = UDT(age=1, mf={2:2}, dummy_udt=None)
obj.update(dummy=None, udt_default=updated_udt)
self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
{'id': 1, 'dummy': None, 'mf': {0: 0}})
{'id': 1, 'dummy': None, 'mf': {0: 0}, "udt": None, "udt_default": updated_udt})
def test_only_set_values_is_updated(self):
"""
@ -244,28 +269,32 @@ class ModelWithDefaultTests(BaseCassEngTestCase):
item = ModelWithDefault.filter(id=1).first()
ModelWithDefault.objects(id=1).delete()
item.mf = {1: 2}
udt, default_udt = UDT(age=1, mf={2:3}), UDT(age=1, mf={2:3})
item.udt, item.default_udt = udt, default_udt
item.save()
self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
{'id': 1, 'dummy': None, 'mf': {1: 2}})
{'id': 1, 'dummy': None, 'mf': {1: 2}, "udt": udt, "udt_default": default_udt})
def test_collections(self):
"""
Test the updates work as expected when an object is deleted
Test the updates work as expected on Map objects
@since 3.9
@jira_ticket PYTHON-657
@expected_result the non updated column is None and the
updated column has the set value
@expected_result the row is updated when the Map object is
reduced
@test_category object_mapper
"""
ModelWithDefault.create(id=1, mf={1: 1, 2: 1}, dummy=1).save()
udt, udt_default = UDT(age=1, mf={1: 1, 2: 1}), UDT(age=1, mf={1: 1, 2: 1})
ModelWithDefault.create(id=1, mf={1: 1, 2: 1}, dummy=1, udt=udt, udt_default=udt_default).save()
item = ModelWithDefault.filter(id=1).first()
udt, udt_default = UDT(age=1, mf={2: 1}), UDT(age=1, mf={2: 1})
item.update(mf={2:1})
self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
{'id': 1, 'dummy': 1, 'mf': {2: 1}})
{'id': 1, 'dummy': 1, 'mf': {2: 1}, "udt": udt, "udt_default": udt_default})
def test_collection_with_default(self):
"""
@ -278,24 +307,32 @@ class ModelWithDefaultTests(BaseCassEngTestCase):
@test_category object_mapper
"""
sync_table(ModelWithDefaultCollection)
item = ModelWithDefaultCollection.create(id=1, mf={1: 1}, dummy=1).save()
self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
{'id': 1, 'dummy': 1, 'mf': {1: 1}})
item.update(mf={2: 2})
self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
{'id': 1, 'dummy': 1, 'mf': {2: 2}})
udt, udt_default = UDT(age=1, mf={6: 6}), UDT(age=1, mf={6: 6})
item.update(mf=None)
item = ModelWithDefaultCollection.create(id=1, mf={1: 1}, dummy=1, udt=udt, udt_default=udt_default).save()
self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
{'id': 1, 'dummy': 1, 'mf': {}})
{'id': 1, 'dummy': 1, 'mf': {1: 1}, "udt": udt, "udt_default": udt_default})
udt, udt_default = UDT(age=1, mf={5: 5}), UDT(age=1, mf={5: 5})
item.update(mf={2: 2}, udt=udt, udt_default=udt_default)
self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
{'id': 1, 'dummy': 1, 'mf': {2: 2}, "udt": udt, "udt_default": udt_default})
udt, udt_default = UDT(age=1, mf=None), UDT(age=1, mf=None)
expected_udt, expected_udt_default = UDT(age=1, mf={}), UDT(age=1, mf={})
item.update(mf=None, udt=udt, udt_default=udt_default)
self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
{'id': 1, 'dummy': 1, 'mf': {}, "udt": expected_udt, "udt_default": expected_udt_default})
udt_default = UDT(age=1, mf=None), UDT(age=1, mf={5:5})
item = ModelWithDefaultCollection.create(id=2, dummy=2).save()
self.assertEqual(ModelWithDefaultCollection.objects().all().get(id=2)._as_dict(),
{'id': 2, 'dummy': 2, 'mf': {2: 2}})
{'id': 2, 'dummy': 2, 'mf': {2: 2}, "udt": None, "udt_default": udt_default})
item.update(mf={1: 1, 4: 4})
udt, udt_default = UDT(age=1, mf={1: 1, 6: 6}), UDT(age=1, mf={1: 1, 6: 6})
item.update(mf={1: 1, 4: 4}, udt=udt, udt_default=udt_default)
self.assertEqual(ModelWithDefaultCollection.objects().all().get(id=2)._as_dict(),
{'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}})
{'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default})
drop_table(ModelWithDefaultCollection)

View File

@ -150,7 +150,7 @@ class TestTokenFunction(BaseCassEngTestCase):
TokenTestModel.create(key=i, val=i)
named = NamedTable(DEFAULT_KEYSPACE, TokenTestModel.__table_name__)
query = named.objects.all().limit(1)
query = named.all().limit(1)
first_page = list(query)
last = first_page[-1]
self.assertTrue(len(first_page) is 1)

View File

@ -18,21 +18,12 @@ from cassandra.cqlengine import ValidationError
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.management import sync_table, drop_table
from cassandra.cqlengine import columns
from tests.integration.cqlengine import is_prepend_reversed
from tests.integration.cqlengine.base import BaseCassEngTestCase
from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel
from tests.integration.cqlengine import execute_count
from tests.integration import greaterthancass20
class TestQueryUpdateModel(Model):
partition = columns.UUID(primary_key=True, default=uuid4)
cluster = columns.Integer(primary_key=True)
count = columns.Integer(required=False)
text = columns.Text(required=False, index=True)
text_set = columns.Set(columns.Text, required=False)
text_list = columns.List(columns.Text, required=False)
text_map = columns.Map(columns.Text, columns.Text, required=False)
class QueryUpdateTests(BaseCassEngTestCase):

View File

@ -16,8 +16,19 @@ try:
except ImportError:
import unittest # noqa
from uuid import uuid4
from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine.statements import BaseCQLStatement
from cassandra.cqlengine.management import sync_table, drop_table
from cassandra.cqlengine.statements import InsertStatement, UpdateStatement, SelectStatement, DeleteStatement, \
WhereClause
from cassandra.cqlengine.operators import EqualsOperator
from cassandra.cqlengine.columns import Column
from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel
from tests.integration.cqlengine import DEFAULT_KEYSPACE
from cassandra.cqlengine.connection import execute
class BaseStatementTest(unittest.TestCase):
@ -32,3 +43,71 @@ class BaseStatementTest(unittest.TestCase):
stmt = BaseCQLStatement('table', None)
self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET)
class ExecuteStatementTest(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(ExecuteStatementTest, cls).setUpClass()
sync_table(TestQueryUpdateModel)
cls.table_name = '{0}.test_query_update_model'.format(DEFAULT_KEYSPACE)
@classmethod
def tearDownClass(cls):
super(ExecuteStatementTest, cls).tearDownClass()
drop_table(TestQueryUpdateModel)
def _verify_statement(self, original):
st = SelectStatement(self.table_name)
result = execute(st)
response = result[0]
for assignment in original.assignments:
self.assertEqual(response[assignment.field], assignment.value)
self.assertEqual(len(response), 7)
def test_insert_statement_execute(self):
"""
Test to verify the execution of BaseCQLStatements using connection.execute
@since 3.10
@jira_ticket PYTHON-505
@expected_result inserts a row in C*, updates the rows and then deletes
all the rows using BaseCQLStatements
@test_category data_types:object_mapper
"""
partition = uuid4()
cluster = 1
#Verifying insert statement
st = InsertStatement(self.table_name)
st.add_assignment(Column(db_field='partition'), partition)
st.add_assignment(Column(db_field='cluster'), cluster)
st.add_assignment(Column(db_field='count'), 1)
st.add_assignment(Column(db_field='text'), "text_for_db")
st.add_assignment(Column(db_field='text_set'), set(("foo", "bar")))
st.add_assignment(Column(db_field='text_list'), ["foo", "bar"])
st.add_assignment(Column(db_field='text_map'), {"foo": '1', "bar": '2'})
execute(st)
self._verify_statement(st)
# Verifying update statement
where = [WhereClause('partition', EqualsOperator(), partition),
WhereClause('cluster', EqualsOperator(), cluster)]
st = UpdateStatement(self.table_name, where=where)
st.add_assignment(Column(db_field='count'), 2)
st.add_assignment(Column(db_field='text'), "text_for_db_update")
st.add_assignment(Column(db_field='text_set'), set(("foo_update", "bar_update")))
st.add_assignment(Column(db_field='text_list'), ["foo_update", "bar_update"])
st.add_assignment(Column(db_field='text_map'), {"foo": '3', "bar": '4'})
execute(st)
self._verify_statement(st)
# Verifying delete statement
execute(DeleteStatement(self.table_name, where=where))
self.assertEqual(TestQueryUpdateModel.objects.count(), 0)