From 7bc0bfd8cbae84a62bd6334f6f74198b760522f7 Mon Sep 17 00:00:00 2001 From: bjmb Date: Tue, 11 Apr 2017 16:28:13 -0400 Subject: [PATCH] Refactored/Added some tests to CQLEngine --- cassandra/cqlengine/connection.py | 1 - tests/integration/cqlengine/__init__.py | 3 +- tests/integration/cqlengine/base.py | 13 + .../cqlengine/columns/test_validation.py | 207 ++++++++++++---- .../integration/cqlengine/model/test_udts.py | 229 ++++++++++-------- .../cqlengine/model/test_updates.py | 101 +++++--- .../cqlengine/query/test_queryoperators.py | 2 +- .../cqlengine/query/test_updates.py | 13 +- .../statements/test_base_statement.py | 79 ++++++ 9 files changed, 459 insertions(+), 189 deletions(-) diff --git a/cassandra/cqlengine/connection.py b/cassandra/cqlengine/connection.py index 8b764998..2f4890d2 100644 --- a/cassandra/cqlengine/connection.py +++ b/cassandra/cqlengine/connection.py @@ -337,7 +337,6 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET, connect query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size) elif isinstance(query, six.string_types): query = SimpleStatement(query, consistency_level=consistency_level) - log.debug(format_log_context(query.query_string, connection=connection)) result = conn.session.execute(query, params, timeout=timeout) diff --git a/tests/integration/cqlengine/__init__.py b/tests/integration/cqlengine/__init__.py index 8fd91ce5..2cb2e28d 100644 --- a/tests/integration/cqlengine/__init__.py +++ b/tests/integration/cqlengine/__init__.py @@ -21,7 +21,7 @@ except ImportError: from cassandra import ConsistencyLevel from cassandra.cqlengine import connection -from cassandra.cqlengine.management import create_keyspace_simple, CQLENG_ALLOW_SCHEMA_MANAGEMENT +from cassandra.cqlengine.management import create_keyspace_simple, drop_keyspace, CQLENG_ALLOW_SCHEMA_MANAGEMENT import cassandra from tests.integration import get_server_versions, use_single_node, PROTOCOL_VERSION, CASSANDRA_IP, set_default_cass_ip @@ -45,7 +45,6 @@ def setup_package(): def teardown_package(): connection.unregister_connection("default") - def is_prepend_reversed(): # do we have https://issues.apache.org/jira/browse/CASSANDRA-8733 ? ver, _ = get_server_versions() diff --git a/tests/integration/cqlengine/base.py b/tests/integration/cqlengine/base.py index 150b9ecf..7be58947 100644 --- a/tests/integration/cqlengine/base.py +++ b/tests/integration/cqlengine/base.py @@ -19,7 +19,20 @@ except ImportError: import sys from cassandra.cqlengine.connection import get_session +from cassandra.cqlengine.models import Model +from cassandra.cqlengine import columns +from uuid import uuid4 + +class TestQueryUpdateModel(Model): + + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.Integer(primary_key=True) + count = columns.Integer(required=False) + text = columns.Text(required=False, index=True) + text_set = columns.Set(columns.Text, required=False) + text_list = columns.List(columns.Text, required=False) + text_map = columns.Map(columns.Text, columns.Text, required=False) class BaseCassEngTestCase(unittest.TestCase): diff --git a/tests/integration/cqlengine/columns/test_validation.py b/tests/integration/cqlengine/columns/test_validation.py index 625b0721..f03bcd5e 100644 --- a/tests/integration/cqlengine/columns/test_validation.py +++ b/tests/integration/cqlengine/columns/test_validation.py @@ -18,30 +18,21 @@ except ImportError: import unittest # noqa import sys -from datetime import datetime, timedelta, date, tzinfo +from datetime import datetime, timedelta, date, tzinfo, time from decimal import Decimal as D from uuid import uuid4, uuid1 from cassandra import InvalidRequest -from cassandra.cqlengine.columns import TimeUUID -from cassandra.cqlengine.columns import Ascii -from cassandra.cqlengine.columns import Text -from cassandra.cqlengine.columns import Integer -from cassandra.cqlengine.columns import BigInt -from cassandra.cqlengine.columns import VarInt -from cassandra.cqlengine.columns import DateTime -from cassandra.cqlengine.columns import Date -from cassandra.cqlengine.columns import UUID -from cassandra.cqlengine.columns import Boolean -from cassandra.cqlengine.columns import Decimal -from cassandra.cqlengine.columns import Inet +from cassandra.cqlengine.columns import TimeUUID, Ascii, Text, Integer, BigInt, VarInt, DateTime, Date, UUID, Boolean, \ + Decimal, Inet, Time, UserDefinedType, Map, List, Set, Tuple, Double, Float from cassandra.cqlengine.connection import execute from cassandra.cqlengine.management import sync_table, drop_table from cassandra.cqlengine.models import Model, ValidationError +from cassandra.cqlengine.usertype import UserType from cassandra import util from tests.integration import PROTOCOL_VERSION -from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel class TestDatetime(BaseCassEngTestCase): @@ -62,7 +53,7 @@ class TestDatetime(BaseCassEngTestCase): now = datetime.now() self.DatetimeTest.objects.create(test_id=0, created_at=now) dt2 = self.DatetimeTest.objects(test_id=0).first() - assert dt2.created_at.timetuple()[:6] == now.timetuple()[:6] + self.assertEqual(dt2.created_at.timetuple()[:6], now.timetuple()[:6]) def test_datetime_tzinfo_io(self): class TZ(tzinfo): @@ -74,21 +65,27 @@ class TestDatetime(BaseCassEngTestCase): now = datetime(1982, 1, 1, tzinfo=TZ()) dt = self.DatetimeTest.objects.create(test_id=1, created_at=now) dt2 = self.DatetimeTest.objects(test_id=1).first() - assert dt2.created_at.timetuple()[:6] == (now + timedelta(hours=1)).timetuple()[:6] + self.assertEqual(dt2.created_at.timetuple()[:6], (now + timedelta(hours=1)).timetuple()[:6]) def test_datetime_date_support(self): today = date.today() self.DatetimeTest.objects.create(test_id=2, created_at=today) dt2 = self.DatetimeTest.objects(test_id=2).first() - assert dt2.created_at.isoformat() == datetime(today.year, today.month, today.day).isoformat() + self.assertEqual(dt2.created_at.isoformat(), datetime(today.year, today.month, today.day).isoformat()) + + result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2).first() + self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time())) + + result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2, created_at=today).first() + self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time())) def test_datetime_none(self): dt = self.DatetimeTest.objects.create(test_id=3, created_at=None) dt2 = self.DatetimeTest.objects(test_id=3).first() - assert dt2.created_at is None + self.assertIsNone(dt2.created_at) dts = self.DatetimeTest.objects.filter(test_id=3).values_list('created_at') - assert dts[0][0] is None + self.assertIsNone(dts[0][0]) def test_datetime_invalid(self): dt_value= 'INVALID' @@ -99,13 +96,13 @@ class TestDatetime(BaseCassEngTestCase): dt_value = 1454520554 self.DatetimeTest.objects.create(test_id=5, created_at=dt_value) dt2 = self.DatetimeTest.objects(test_id=5).first() - assert dt2.created_at == datetime.utcfromtimestamp(dt_value) + self.assertEqual(dt2.created_at, datetime.utcfromtimestamp(dt_value)) def test_datetime_large(self): dt_value = datetime(2038, 12, 31, 10, 10, 10, 123000) self.DatetimeTest.objects.create(test_id=6, created_at=dt_value) dt2 = self.DatetimeTest.objects(test_id=6).first() - assert dt2.created_at == dt_value + self.assertEqual(dt2.created_at, dt_value) def test_datetime_truncate_microseconds(self): """ @@ -187,49 +184,170 @@ class TestVarInt(BaseCassEngTestCase): int2 = self.VarIntTest.objects(test_id=0).first() self.assertEqual(int1.bignum, int2.bignum) + with self.assertRaises(ValidationError): + self.VarIntTest.objects.create(test_id=0, bignum="not_a_number") -class TestDate(BaseCassEngTestCase): - class DateTest(Model): - - test_id = Integer(primary_key=True) - created_at = Date() +class DataType(): @classmethod def setUpClass(cls): if PROTOCOL_VERSION < 4: return - sync_table(cls.DateTest) + + class DataTypeTest(Model): + test_id = Integer(primary_key=True) + class_param = cls.db_klass() + + cls.model_class = DataTypeTest + sync_table(cls.model_class) @classmethod def tearDownClass(cls): if PROTOCOL_VERSION < 4: return - drop_table(cls.DateTest) + drop_table(cls.model_class) def setUp(self): if PROTOCOL_VERSION < 4: raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) - def test_date_io(self): - today = date.today() - self.DateTest.objects.create(test_id=0, created_at=today) - result = self.DateTest.objects(test_id=0).first() - self.assertEqual(result.created_at, util.Date(today)) + def _check_value_is_correct_in_db(self, value): + + if value is None: + result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first() + self.assertIsNone(result.class_param) + + else: + if not isinstance(value, self.python_klass): + value_to_compare = self.python_klass(value) + else: + value_to_compare = value + + result = self.model_class.objects(test_id=0).first() + self.assertIsInstance(result.class_param, self.python_klass) + self.assertEqual(result.class_param, value_to_compare) + + result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first() + self.assertIsInstance(result.class_param, self.python_klass) + self.assertEqual(result.class_param, value_to_compare) + + result = self.model_class.objects.all().allow_filtering().filter(test_id=0, class_param=value).first() + self.assertIsInstance(result.class_param, self.python_klass) + self.assertEqual(result.class_param, value_to_compare) + + return result def test_date_io_using_datetime(self): - now = datetime.utcnow() - self.DateTest.objects.create(test_id=0, created_at=now) - result = self.DateTest.objects(test_id=0).first() - self.assertIsInstance(result.created_at, util.Date) - self.assertEqual(result.created_at, util.Date(now)) + first_value = self.first_value + second_value = self.second_value + third_value = self.third_value + + self.model_class.objects.create(test_id=0, class_param=first_value) + result = self._check_value_is_correct_in_db(first_value) + result.delete() + + self.model_class.objects.create(test_id=0, class_param=second_value) + result = self._check_value_is_correct_in_db(second_value) + + result.update(class_param=third_value).save() + result = self._check_value_is_correct_in_db(third_value) + + result.update(class_param=None).save() + self._check_value_is_correct_in_db(None) def test_date_none(self): - self.DateTest.objects.create(test_id=1, created_at=None) - dt2 = self.DateTest.objects(test_id=1).first() - assert dt2.created_at is None + self.model_class.objects.create(test_id=1, class_param=None) + dt2 = self.model_class.objects(test_id=1).first() + self.assertIsNone(dt2.class_param) - dts = self.DateTest.objects(test_id=1).values_list('created_at') - assert dts[0][0] is None + dts = self.model_class.objects(test_id=1).values_list('class_param') + self.assertIsNone(dts[0][0]) + + +class TestDate(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = Date, util.Date + cls.first_value, cls.second_value, cls.third_value = \ + datetime.utcnow(), util.Date(datetime(1, 1, 1)), datetime(1, 1, 2) + super(TestDate, cls).setUpClass() + + +class TestTime(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = Time, util.Time + cls.first_value, cls.second_value, cls.third_value = \ + time(2, 12, 7, 48), util.Time(time(2, 12, 7, 49)), time(2, 12, 7, 50) + super(TestTime, cls).setUpClass() + + +class TestDateTime(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = DateTime, datetime + cls.first_value, cls.second_value, cls.third_value = \ + datetime(2017, 4, 13, 18, 34, 24, 317000), datetime(1, 1, 1), datetime(1, 1, 2) + super(TestDateTime, cls).setUpClass() + + +class TestBoolean(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = Boolean, bool + cls.first_value, cls.second_value, cls.third_value = True, False, True + super(TestBoolean, cls).setUpClass() + + +class User(UserType): + # We use Date and Time to ensure to_python + # is called for these columns + age = Integer() + date_param = Date() + map_param = Map(Integer, Time) + list_param = List(Date) + set_param = Set(Date) + tuple_param = Tuple(Date, Decimal, Boolean, VarInt, Double, UUID) + + +class UserModel(Model): + test_id = Integer(primary_key=True) + class_param = UserDefinedType(User) + + +class TestUDT(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = UserDefinedType, User + cls.first_value = User( + age=1, + date_param=datetime.utcnow(), + map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))}, + list_param=[datetime(1, 1, 2), datetime(1, 1, 3)], + set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 1)))), + tuple_param=(datetime(1, 1, 3), 2, False, 1, 2.324, uuid4()) + ) + + cls.second_value = User( + age=1, + date_param=datetime.utcnow(), + map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))}, + list_param=[datetime(1, 1, 2), datetime(1, 2, 3)], + set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 1)))), + tuple_param=(datetime(1, 1, 2), 2, False, 1, 2.324, uuid4()) + ) + + cls.third_value = User( + age=2, + date_param=datetime.utcnow(), + map_param={1: time(2, 12, 7, 51), 2: util.Time(time(2, 12, 7, 49))}, + list_param=[datetime(1, 1, 2), datetime(1, 1, 4)], + set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 2)))), + tuple_param=(datetime(1, 1, 2), 3, False, 1, 2.3214, uuid4()) + ) + + cls.model_class = UserModel + sync_table(cls.model_class) class TestDecimal(BaseCassEngTestCase): @@ -319,7 +437,7 @@ class TestInteger(BaseCassEngTestCase): class IntegerTest(Model): test_id = UUID(primary_key=True, default=lambda:uuid4()) - value = Integer(default=0, required=True) + value = Integer(default=0, required=True) def test_default_zero_fields_validate(self): """ Tests that integer columns with a default value of 0 validate """ @@ -644,4 +762,3 @@ class TestInet(BaseCassEngTestCase): # TODO: presently this only tests that the server blows it up. Is there supposed to be local validation? with self.assertRaises(InvalidRequest): self.InetTestModel.create(address="what is going on here?") - diff --git a/tests/integration/cqlengine/model/test_udts.py b/tests/integration/cqlengine/model/test_udts.py index fff7001f..e208787e 100644 --- a/tests/integration/cqlengine/model/test_udts.py +++ b/tests/integration/cqlengine/model/test_udts.py @@ -24,11 +24,45 @@ from uuid import UUID, uuid4 from cassandra.cqlengine.models import Model from cassandra.cqlengine.usertype import UserType, UserTypeDefinitionException from cassandra.cqlengine import columns, connection -from cassandra.cqlengine.management import sync_table, sync_type, create_keyspace_simple, drop_keyspace, drop_table +from cassandra.cqlengine.management import sync_table, drop_table, sync_type, create_keyspace_simple, drop_keyspace +from cassandra.cqlengine import ValidationError from cassandra.util import Date, Time from tests.integration import PROTOCOL_VERSION from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine import DEFAULT_KEYSPACE + + +class User(UserType): + age = columns.Integer() + name = columns.Text() + + +class UserModel(Model): + id = columns.Integer(primary_key=True) + info = columns.UserDefinedType(User) + + +class AllDatatypes(UserType): + a = columns.Ascii() + b = columns.BigInt() + c = columns.Blob() + d = columns.Boolean() + e = columns.DateTime() + f = columns.Decimal() + g = columns.Double() + h = columns.Float() + i = columns.Inet() + j = columns.Integer() + k = columns.Text() + l = columns.TimeUUID() + m = columns.UUID() + n = columns.VarInt() + + +class AllDatatypesModel(Model): + id = columns.Integer(primary_key=True) + data = columns.UserDefinedType(AllDatatypes) class UserDefinedTypeTests(BaseCassEngTestCase): @@ -42,7 +76,7 @@ class UserDefinedTypeTests(BaseCassEngTestCase): age = columns.Integer() name = columns.Text() - sync_type("cqlengine_test", User) + sync_type(DEFAULT_KEYSPACE, User) user = User(age=42, name="John") self.assertEqual(42, user.age) self.assertEqual("John", user.name) @@ -53,8 +87,10 @@ class UserDefinedTypeTests(BaseCassEngTestCase): name = columns.Text() gender = columns.Text() - sync_type("cqlengine_test", User) - user = User(age=42, name="John", gender="male") + sync_type(DEFAULT_KEYSPACE, User) + user = User(age=42) + user["name"] = "John" + user["gender"] = "male" self.assertEqual(42, user.age) self.assertEqual("John", user.name) self.assertEqual("male", user.gender) @@ -64,19 +100,12 @@ class UserDefinedTypeTests(BaseCassEngTestCase): age = columns.Integer() name = columns.Text() - sync_type("cqlengine_test", User) + sync_type(DEFAULT_KEYSPACE, User) user = User(age=42, name="John", gender="male") with self.assertRaises(AttributeError): user.gender def test_can_insert_udts(self): - class User(UserType): - age = columns.Integer() - name = columns.Text() - - class UserModel(Model): - id = columns.Integer(primary_key=True) - info = columns.UserDefinedType(User) sync_table(UserModel) @@ -90,16 +119,9 @@ class UserDefinedTypeTests(BaseCassEngTestCase): self.assertTrue(type(john.info) is User) self.assertEqual(42, john.info.age) self.assertEqual("John", john.info.name) + drop_table(UserModel) def test_can_update_udts(self): - class User(UserType): - age = columns.Integer() - name = columns.Text() - - class UserModel(Model): - id = columns.Integer(primary_key=True) - info = columns.UserDefinedType(User) - sync_table(UserModel) user = User(age=42, name="John") @@ -113,18 +135,11 @@ class UserDefinedTypeTests(BaseCassEngTestCase): created_user.update() mary_info = UserModel.objects().first().info - self.assertEqual(22, mary_info.age) - self.assertEqual("Mary", mary_info.name) + self.assertEqual(22, mary_info["age"]) + self.assertEqual("Mary", mary_info["name"]) + drop_table(UserModel) def test_can_update_udts_with_nones(self): - class User(UserType): - age = columns.Integer() - name = columns.Text() - - class UserModel(Model): - id = columns.Integer(primary_key=True) - info = columns.UserDefinedType(User) - sync_table(UserModel) user = User(age=42, name="John") @@ -139,45 +154,43 @@ class UserDefinedTypeTests(BaseCassEngTestCase): john_info = UserModel.objects().first().info self.assertIsNone(john_info) + drop_table(UserModel) def test_can_create_same_udt_different_keyspaces(self): - class User(UserType): - age = columns.Integer() - name = columns.Text() - - sync_type("cqlengine_test", User) + sync_type(DEFAULT_KEYSPACE, User) create_keyspace_simple("simplex", 1) sync_type("simplex", User) drop_keyspace("simplex") def test_can_insert_partial_udts(self): - class User(UserType): + class UserGender(UserType): age = columns.Integer() name = columns.Text() gender = columns.Text() - class UserModel(Model): + class UserModelGender(Model): id = columns.Integer(primary_key=True) - info = columns.UserDefinedType(User) + info = columns.UserDefinedType(UserGender) - sync_table(UserModel) + sync_table(UserModelGender) - user = User(age=42, name="John") - UserModel.create(id=0, info=user) + user = UserGender(age=42, name="John") + UserModelGender.create(id=0, info=user) - john_info = UserModel.objects().first().info + john_info = UserModelGender.objects().first().info self.assertEqual(42, john_info.age) self.assertEqual("John", john_info.name) self.assertIsNone(john_info.gender) - user = User(age=42) - UserModel.create(id=0, info=user) + user = UserGender(age=42) + UserModelGender.create(id=0, info=user) - john_info = UserModel.objects().first().info + john_info = UserModelGender.objects().first().info self.assertEqual(42, john_info.age) self.assertIsNone(john_info.name) self.assertIsNone(john_info.gender) + drop_table(UserModelGender) def test_can_insert_nested_udts(self): class Depth_0(UserType): @@ -215,6 +228,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase): self.assertEqual(udts[2], output.v_2) self.assertEqual(udts[3], output.v_3) + drop_table(DepthModel) + def test_can_insert_udts_with_nones(self): """ Test for inserting all column types as empty into a UserType as None's @@ -230,27 +245,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase): @test_category data_types:udt """ - - class AllDatatypes(UserType): - a = columns.Ascii() - b = columns.BigInt() - c = columns.Blob() - d = columns.Boolean() - e = columns.DateTime() - f = columns.Decimal() - g = columns.Double() - h = columns.Float() - i = columns.Inet() - j = columns.Integer() - k = columns.Text() - l = columns.TimeUUID() - m = columns.UUID() - n = columns.VarInt() - - class AllDatatypesModel(Model): - id = columns.Integer(primary_key=True) - data = columns.UserDefinedType(AllDatatypes) - sync_table(AllDatatypesModel) input = AllDatatypes(a=None, b=None, c=None, d=None, e=None, f=None, g=None, h=None, i=None, j=None, k=None, @@ -262,6 +256,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase): output = AllDatatypesModel.objects().first().data self.assertEqual(input, output) + drop_table(AllDatatypesModel) + def test_can_insert_udts_with_all_datatypes(self): """ Test for inserting all column types into a UserType @@ -277,27 +273,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase): @test_category data_types:udt """ - - class AllDatatypes(UserType): - a = columns.Ascii() - b = columns.BigInt() - c = columns.Blob() - d = columns.Boolean() - e = columns.DateTime() - f = columns.Decimal() - g = columns.Double() - h = columns.Float() - i = columns.Inet() - j = columns.Integer() - k = columns.Text() - l = columns.TimeUUID() - m = columns.UUID() - n = columns.VarInt() - - class AllDatatypesModel(Model): - id = columns.Integer(primary_key=True) - data = columns.UserDefinedType(AllDatatypes) - sync_table(AllDatatypesModel) input = AllDatatypes(a='ascii', b=2 ** 63 - 1, c=bytearray(b'hello world'), d=True, @@ -313,6 +288,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase): for i in range(ord('a'), ord('a') + 14): self.assertEqual(input[chr(i)], output[chr(i)]) + drop_table(AllDatatypesModel) + def test_can_insert_udts_protocol_v4_datatypes(self): """ Test for inserting all protocol v4 column types into a UserType @@ -354,6 +331,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase): for i in range(ord('a'), ord('a') + 3): self.assertEqual(input[chr(i)], output[chr(i)]) + drop_table(Allv4DatatypesModel) + def test_nested_udts_inserts(self): """ Test for inserting collections of user types using cql engine. @@ -394,6 +373,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase): names_output = Container.objects().first().names self.assertEqual(names_output, names) + drop_table(Container) + def test_udts_with_unicode(self): """ Test for inserting models with unicode and udt columns. @@ -410,10 +391,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase): ascii_name = 'normal name' unicode_name = u'Fran\u00E7ois' - class User(UserType): - age = columns.Integer() - name = columns.Text() - class UserModelText(Model): id = columns.Text(primary_key=True) info = columns.UserDefinedType(User) @@ -427,10 +404,9 @@ class UserDefinedTypeTests(BaseCassEngTestCase): UserModelText.create(id=unicode_name, info=user_template_ascii) UserModelText.create(id=unicode_name, info=user_template_unicode) + drop_table(UserModelText) + def test_register_default_keyspace(self): - class User(UserType): - age = columns.Integer() - name = columns.Text() from cassandra.cqlengine import models from cassandra.cqlengine import connection @@ -495,6 +471,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase): self.assertEqual(info.a, age) self.assertEqual(info.n, name) + drop_table(TheModel) + def test_db_field_overload(self): """ Tests for db_field UserTypeDefinitionException @@ -520,9 +498,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase): def test_set_udt_fields(self): # PYTHON-502 - class User(UserType): - age = columns.Integer() - name = columns.Text() u = User() u.age = 20 @@ -562,3 +537,63 @@ class UserDefinedTypeTests(BaseCassEngTestCase): self.assertEqual(t.nested[0].default_text, "default text") self.assertIsNotNone(t.simple.test_id) self.assertEqual(t.simple.default_text, "default text") + + drop_table(OuterModel) + + def test_udt_validate(self): + """ + Test to verify restrictions are honored and that validate is called + for each member of the UDT when an updated is attempted + + @since 3.10 + @jira_ticket PYTHON-505 + @expected_result a validation error is arisen due to the name being + too long + + @test_category data_types:object_mapper + """ + class UserValidate(UserType): + age = columns.Integer() + name = columns.Text(max_length=2) + + class UserModelValidate(Model): + id = columns.Integer(primary_key=True) + info = columns.UserDefinedType(UserValidate) + + sync_table(UserModelValidate) + + user = UserValidate(age=1, name="Robert") + item = UserModelValidate(id=1, info=user) + with self.assertRaises(ValidationError): + item.save() + + drop_table(UserModelValidate) + + def test_udt_validate_with_default(self): + """ + Test to verify restrictions are honored and that validate is called + on the default value + + @since 3.10 + @jira_ticket PYTHON-505 + @expected_result a validation error is arisen due to the name being + too long + + @test_category data_types:object_mapper + """ + class UserValidateDefault(UserType): + age = columns.Integer() + name = columns.Text(max_length=2, default="Robert") + + class UserModelValidateDefault(Model): + id = columns.Integer(primary_key=True) + info = columns.UserDefinedType(UserValidateDefault) + + sync_table(UserModelValidateDefault) + + user = UserValidateDefault(age=1) + item = UserModelValidateDefault(id=1, info=user) + with self.assertRaises(ValidationError): + item.save() + + drop_table(UserModelValidateDefault) diff --git a/tests/integration/cqlengine/model/test_updates.py b/tests/integration/cqlengine/model/test_updates.py index bfab3af5..79c1372a 100644 --- a/tests/integration/cqlengine/model/test_updates.py +++ b/tests/integration/cqlengine/model/test_updates.py @@ -21,7 +21,7 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase from cassandra.cqlengine.models import Model from cassandra.cqlengine import columns from cassandra.cqlengine.management import sync_table, drop_table - +from cassandra.cqlengine.usertype import UserType class TestUpdateModel(Model): @@ -147,19 +147,40 @@ class ModelUpdateTests(BaseCassEngTestCase): m0.update(partition=uuid4()) +class UDT(UserType): + age = columns.Integer() + mf = columns.Map(columns.Integer, columns.Integer) + dummy_udt = columns.Integer(default=42) + + class ModelWithDefault(Model): - id = columns.Integer(primary_key=True) - mf = columns.Map(columns.Integer, columns.Integer) - dummy = columns.Integer(default=42) + id = columns.Integer(primary_key=True) + mf = columns.Map(columns.Integer, columns.Integer) + dummy = columns.Integer(default=42) + udt = columns.UserDefinedType(UDT) + udt_default = columns.UserDefinedType(UDT, default=UDT(age=1, mf={2:2})) + + +class UDTWithDefault(UserType): + age = columns.Integer() + mf = columns.Map(columns.Integer, columns.Integer, default={2:2}) + dummy_udt = columns.Integer(default=42) class ModelWithDefaultCollection(Model): - id = columns.Integer(primary_key=True) - mf = columns.Map(columns.Integer, columns.Integer, default={2:2}) - dummy = columns.Integer(default=42) + id = columns.Integer(primary_key=True) + mf = columns.Map(columns.Integer, columns.Integer, default={2:2}) + dummy = columns.Integer(default=42) + udt = columns.UserDefinedType(UDT) + udt_default = columns.UserDefinedType(UDT, default=UDT(age=1, mf={2: 2})) class ModelWithDefaultTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + cls.default_udt = UDT(age=1, mf={2:2}, dummy_udt=42) + def setUp(self): sync_table(ModelWithDefault) @@ -176,17 +197,19 @@ class ModelWithDefaultTests(BaseCassEngTestCase): @test_category object_mapper """ - initial = ModelWithDefault(id=1, mf={0: 0}, dummy=0) + first_udt = UDT(age=1, mf={2:2}, dummy_udt=0) + initial = ModelWithDefault(id=1, mf={0: 0}, dummy=0, udt=first_udt, udt_default=first_udt) initial.save() self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(), - {'id': 1, 'dummy': 0, 'mf': {0: 0}}) + {'id': 1, 'dummy': 0, 'mf': {0: 0}, "udt": first_udt, "udt_default": first_udt}) + second_udt = UDT(age=1, mf={3: 3}, dummy_udt=12) second = ModelWithDefault(id=1) - second.update(mf={0: 1}) + second.update(mf={0: 1}, udt=second_udt) self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(), - {'id': 1, 'dummy': 0, 'mf': {0: 1}}) + {'id': 1, 'dummy': 0, 'mf': {0: 1}, "udt": second_udt, "udt_default": first_udt}) def test_value_is_written_if_is_default(self): """ @@ -202,10 +225,11 @@ class ModelWithDefaultTests(BaseCassEngTestCase): initial = ModelWithDefault(id=1) initial.mf = {0: 0} initial.dummy = 42 + initial.udt_default = self.default_udt initial.update() self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(), - {'id': 1, 'dummy': 42, 'mf': {0: 0}}) + {'id': 1, 'dummy': 42, 'mf': {0: 0}, "udt": None, "udt_default": self.default_udt}) def test_null_update_is_respected(self): """ @@ -223,10 +247,11 @@ class ModelWithDefaultTests(BaseCassEngTestCase): q = ModelWithDefault.objects.all().allow_filtering() obj = q.filter(id=1).get() - obj.update(dummy=None) + updated_udt = UDT(age=1, mf={2:2}, dummy_udt=None) + obj.update(dummy=None, udt_default=updated_udt) self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(), - {'id': 1, 'dummy': None, 'mf': {0: 0}}) + {'id': 1, 'dummy': None, 'mf': {0: 0}, "udt": None, "udt_default": updated_udt}) def test_only_set_values_is_updated(self): """ @@ -244,28 +269,32 @@ class ModelWithDefaultTests(BaseCassEngTestCase): item = ModelWithDefault.filter(id=1).first() ModelWithDefault.objects(id=1).delete() item.mf = {1: 2} - + udt, default_udt = UDT(age=1, mf={2:3}), UDT(age=1, mf={2:3}) + item.udt, item.default_udt = udt, default_udt item.save() self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(), - {'id': 1, 'dummy': None, 'mf': {1: 2}}) + {'id': 1, 'dummy': None, 'mf': {1: 2}, "udt": udt, "udt_default": default_udt}) def test_collections(self): """ - Test the updates work as expected when an object is deleted + Test the updates work as expected on Map objects @since 3.9 @jira_ticket PYTHON-657 - @expected_result the non updated column is None and the - updated column has the set value + @expected_result the row is updated when the Map object is + reduced @test_category object_mapper """ - ModelWithDefault.create(id=1, mf={1: 1, 2: 1}, dummy=1).save() + udt, udt_default = UDT(age=1, mf={1: 1, 2: 1}), UDT(age=1, mf={1: 1, 2: 1}) + + ModelWithDefault.create(id=1, mf={1: 1, 2: 1}, dummy=1, udt=udt, udt_default=udt_default).save() item = ModelWithDefault.filter(id=1).first() + udt, udt_default = UDT(age=1, mf={2: 1}), UDT(age=1, mf={2: 1}) item.update(mf={2:1}) self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(), - {'id': 1, 'dummy': 1, 'mf': {2: 1}}) + {'id': 1, 'dummy': 1, 'mf': {2: 1}, "udt": udt, "udt_default": udt_default}) def test_collection_with_default(self): """ @@ -278,24 +307,32 @@ class ModelWithDefaultTests(BaseCassEngTestCase): @test_category object_mapper """ sync_table(ModelWithDefaultCollection) - item = ModelWithDefaultCollection.create(id=1, mf={1: 1}, dummy=1).save() - self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(), - {'id': 1, 'dummy': 1, 'mf': {1: 1}}) - item.update(mf={2: 2}) - self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(), - {'id': 1, 'dummy': 1, 'mf': {2: 2}}) + udt, udt_default = UDT(age=1, mf={6: 6}), UDT(age=1, mf={6: 6}) - item.update(mf=None) + item = ModelWithDefaultCollection.create(id=1, mf={1: 1}, dummy=1, udt=udt, udt_default=udt_default).save() self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(), - {'id': 1, 'dummy': 1, 'mf': {}}) + {'id': 1, 'dummy': 1, 'mf': {1: 1}, "udt": udt, "udt_default": udt_default}) + udt, udt_default = UDT(age=1, mf={5: 5}), UDT(age=1, mf={5: 5}) + item.update(mf={2: 2}, udt=udt, udt_default=udt_default) + self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(), + {'id': 1, 'dummy': 1, 'mf': {2: 2}, "udt": udt, "udt_default": udt_default}) + + udt, udt_default = UDT(age=1, mf=None), UDT(age=1, mf=None) + expected_udt, expected_udt_default = UDT(age=1, mf={}), UDT(age=1, mf={}) + item.update(mf=None, udt=udt, udt_default=udt_default) + self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(), + {'id': 1, 'dummy': 1, 'mf': {}, "udt": expected_udt, "udt_default": expected_udt_default}) + + udt_default = UDT(age=1, mf=None), UDT(age=1, mf={5:5}) item = ModelWithDefaultCollection.create(id=2, dummy=2).save() self.assertEqual(ModelWithDefaultCollection.objects().all().get(id=2)._as_dict(), - {'id': 2, 'dummy': 2, 'mf': {2: 2}}) + {'id': 2, 'dummy': 2, 'mf': {2: 2}, "udt": None, "udt_default": udt_default}) - item.update(mf={1: 1, 4: 4}) + udt, udt_default = UDT(age=1, mf={1: 1, 6: 6}), UDT(age=1, mf={1: 1, 6: 6}) + item.update(mf={1: 1, 4: 4}, udt=udt, udt_default=udt_default) self.assertEqual(ModelWithDefaultCollection.objects().all().get(id=2)._as_dict(), - {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}}) + {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default}) drop_table(ModelWithDefaultCollection) diff --git a/tests/integration/cqlengine/query/test_queryoperators.py b/tests/integration/cqlengine/query/test_queryoperators.py index f52db06f..46d46d65 100644 --- a/tests/integration/cqlengine/query/test_queryoperators.py +++ b/tests/integration/cqlengine/query/test_queryoperators.py @@ -150,7 +150,7 @@ class TestTokenFunction(BaseCassEngTestCase): TokenTestModel.create(key=i, val=i) named = NamedTable(DEFAULT_KEYSPACE, TokenTestModel.__table_name__) - query = named.objects.all().limit(1) + query = named.all().limit(1) first_page = list(query) last = first_page[-1] self.assertTrue(len(first_page) is 1) diff --git a/tests/integration/cqlengine/query/test_updates.py b/tests/integration/cqlengine/query/test_updates.py index 2daa3a48..7c4917be 100644 --- a/tests/integration/cqlengine/query/test_updates.py +++ b/tests/integration/cqlengine/query/test_updates.py @@ -18,21 +18,12 @@ from cassandra.cqlengine import ValidationError from cassandra.cqlengine.models import Model from cassandra.cqlengine.management import sync_table, drop_table from cassandra.cqlengine import columns + from tests.integration.cqlengine import is_prepend_reversed -from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel from tests.integration.cqlengine import execute_count from tests.integration import greaterthancass20 -class TestQueryUpdateModel(Model): - - partition = columns.UUID(primary_key=True, default=uuid4) - cluster = columns.Integer(primary_key=True) - count = columns.Integer(required=False) - text = columns.Text(required=False, index=True) - text_set = columns.Set(columns.Text, required=False) - text_list = columns.List(columns.Text, required=False) - text_map = columns.Map(columns.Text, columns.Text, required=False) - class QueryUpdateTests(BaseCassEngTestCase): diff --git a/tests/integration/cqlengine/statements/test_base_statement.py b/tests/integration/cqlengine/statements/test_base_statement.py index 02936077..21388d0c 100644 --- a/tests/integration/cqlengine/statements/test_base_statement.py +++ b/tests/integration/cqlengine/statements/test_base_statement.py @@ -16,8 +16,19 @@ try: except ImportError: import unittest # noqa +from uuid import uuid4 + from cassandra.query import FETCH_SIZE_UNSET from cassandra.cqlengine.statements import BaseCQLStatement +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.statements import InsertStatement, UpdateStatement, SelectStatement, DeleteStatement, \ + WhereClause +from cassandra.cqlengine.operators import EqualsOperator +from cassandra.cqlengine.columns import Column + +from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel +from tests.integration.cqlengine import DEFAULT_KEYSPACE +from cassandra.cqlengine.connection import execute class BaseStatementTest(unittest.TestCase): @@ -32,3 +43,71 @@ class BaseStatementTest(unittest.TestCase): stmt = BaseCQLStatement('table', None) self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET) + + +class ExecuteStatementTest(BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + super(ExecuteStatementTest, cls).setUpClass() + sync_table(TestQueryUpdateModel) + cls.table_name = '{0}.test_query_update_model'.format(DEFAULT_KEYSPACE) + + @classmethod + def tearDownClass(cls): + super(ExecuteStatementTest, cls).tearDownClass() + drop_table(TestQueryUpdateModel) + + def _verify_statement(self, original): + st = SelectStatement(self.table_name) + result = execute(st) + response = result[0] + + for assignment in original.assignments: + self.assertEqual(response[assignment.field], assignment.value) + self.assertEqual(len(response), 7) + + def test_insert_statement_execute(self): + """ + Test to verify the execution of BaseCQLStatements using connection.execute + + @since 3.10 + @jira_ticket PYTHON-505 + @expected_result inserts a row in C*, updates the rows and then deletes + all the rows using BaseCQLStatements + + @test_category data_types:object_mapper + """ + partition = uuid4() + cluster = 1 + + #Verifying insert statement + st = InsertStatement(self.table_name) + st.add_assignment(Column(db_field='partition'), partition) + st.add_assignment(Column(db_field='cluster'), cluster) + + st.add_assignment(Column(db_field='count'), 1) + st.add_assignment(Column(db_field='text'), "text_for_db") + st.add_assignment(Column(db_field='text_set'), set(("foo", "bar"))) + st.add_assignment(Column(db_field='text_list'), ["foo", "bar"]) + st.add_assignment(Column(db_field='text_map'), {"foo": '1', "bar": '2'}) + + execute(st) + self._verify_statement(st) + + # Verifying update statement + where = [WhereClause('partition', EqualsOperator(), partition), + WhereClause('cluster', EqualsOperator(), cluster)] + + st = UpdateStatement(self.table_name, where=where) + st.add_assignment(Column(db_field='count'), 2) + st.add_assignment(Column(db_field='text'), "text_for_db_update") + st.add_assignment(Column(db_field='text_set'), set(("foo_update", "bar_update"))) + st.add_assignment(Column(db_field='text_list'), ["foo_update", "bar_update"]) + st.add_assignment(Column(db_field='text_map'), {"foo": '3', "bar": '4'}) + + execute(st) + self._verify_statement(st) + + # Verifying delete statement + execute(DeleteStatement(self.table_name, where=where)) + self.assertEqual(TestQueryUpdateModel.objects.count(), 0)