Refactored/Added some tests to CQLEngine
This commit is contained in:
		@@ -337,7 +337,6 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET, connect
 | 
				
			|||||||
        query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
 | 
					        query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
 | 
				
			||||||
    elif isinstance(query, six.string_types):
 | 
					    elif isinstance(query, six.string_types):
 | 
				
			||||||
        query = SimpleStatement(query, consistency_level=consistency_level)
 | 
					        query = SimpleStatement(query, consistency_level=consistency_level)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    log.debug(format_log_context(query.query_string, connection=connection))
 | 
					    log.debug(format_log_context(query.query_string, connection=connection))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    result = conn.session.execute(query, params, timeout=timeout)
 | 
					    result = conn.session.execute(query, params, timeout=timeout)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,7 +21,7 @@ except ImportError:
 | 
				
			|||||||
from cassandra import ConsistencyLevel
 | 
					from cassandra import ConsistencyLevel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from cassandra.cqlengine import connection
 | 
					from cassandra.cqlengine import connection
 | 
				
			||||||
from cassandra.cqlengine.management import create_keyspace_simple, CQLENG_ALLOW_SCHEMA_MANAGEMENT
 | 
					from cassandra.cqlengine.management import create_keyspace_simple, drop_keyspace, CQLENG_ALLOW_SCHEMA_MANAGEMENT
 | 
				
			||||||
import cassandra
 | 
					import cassandra
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from tests.integration import get_server_versions, use_single_node, PROTOCOL_VERSION, CASSANDRA_IP, set_default_cass_ip
 | 
					from tests.integration import get_server_versions, use_single_node, PROTOCOL_VERSION, CASSANDRA_IP, set_default_cass_ip
 | 
				
			||||||
@@ -45,7 +45,6 @@ def setup_package():
 | 
				
			|||||||
def teardown_package():
 | 
					def teardown_package():
 | 
				
			||||||
    connection.unregister_connection("default")
 | 
					    connection.unregister_connection("default")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
def is_prepend_reversed():
 | 
					def is_prepend_reversed():
 | 
				
			||||||
    # do we have https://issues.apache.org/jira/browse/CASSANDRA-8733 ?
 | 
					    # do we have https://issues.apache.org/jira/browse/CASSANDRA-8733 ?
 | 
				
			||||||
    ver, _ = get_server_versions()
 | 
					    ver, _ = get_server_versions()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,7 +19,20 @@ except ImportError:
 | 
				
			|||||||
import sys
 | 
					import sys
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from cassandra.cqlengine.connection import get_session
 | 
					from cassandra.cqlengine.connection import get_session
 | 
				
			||||||
 | 
					from cassandra.cqlengine.models import Model
 | 
				
			||||||
 | 
					from cassandra.cqlengine import columns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from uuid import uuid4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestQueryUpdateModel(Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    partition = columns.UUID(primary_key=True, default=uuid4)
 | 
				
			||||||
 | 
					    cluster = columns.Integer(primary_key=True)
 | 
				
			||||||
 | 
					    count = columns.Integer(required=False)
 | 
				
			||||||
 | 
					    text = columns.Text(required=False, index=True)
 | 
				
			||||||
 | 
					    text_set = columns.Set(columns.Text, required=False)
 | 
				
			||||||
 | 
					    text_list = columns.List(columns.Text, required=False)
 | 
				
			||||||
 | 
					    text_map = columns.Map(columns.Text, columns.Text, required=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseCassEngTestCase(unittest.TestCase):
 | 
					class BaseCassEngTestCase(unittest.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,30 +18,21 @@ except ImportError:
 | 
				
			|||||||
    import unittest  # noqa
 | 
					    import unittest  # noqa
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
from datetime import datetime, timedelta, date, tzinfo
 | 
					from datetime import datetime, timedelta, date, tzinfo, time
 | 
				
			||||||
from decimal import Decimal as D
 | 
					from decimal import Decimal as D
 | 
				
			||||||
from uuid import uuid4, uuid1
 | 
					from uuid import uuid4, uuid1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from cassandra import InvalidRequest
 | 
					from cassandra import InvalidRequest
 | 
				
			||||||
from cassandra.cqlengine.columns import TimeUUID
 | 
					from cassandra.cqlengine.columns import TimeUUID, Ascii, Text, Integer, BigInt, VarInt, DateTime, Date, UUID, Boolean, \
 | 
				
			||||||
from cassandra.cqlengine.columns import Ascii
 | 
					    Decimal, Inet, Time, UserDefinedType, Map, List, Set, Tuple, Double, Float
 | 
				
			||||||
from cassandra.cqlengine.columns import Text
 | 
					 | 
				
			||||||
from cassandra.cqlengine.columns import Integer
 | 
					 | 
				
			||||||
from cassandra.cqlengine.columns import BigInt
 | 
					 | 
				
			||||||
from cassandra.cqlengine.columns import VarInt
 | 
					 | 
				
			||||||
from cassandra.cqlengine.columns import DateTime
 | 
					 | 
				
			||||||
from cassandra.cqlengine.columns import Date
 | 
					 | 
				
			||||||
from cassandra.cqlengine.columns import UUID
 | 
					 | 
				
			||||||
from cassandra.cqlengine.columns import Boolean
 | 
					 | 
				
			||||||
from cassandra.cqlengine.columns import Decimal
 | 
					 | 
				
			||||||
from cassandra.cqlengine.columns import Inet
 | 
					 | 
				
			||||||
from cassandra.cqlengine.connection import execute
 | 
					from cassandra.cqlengine.connection import execute
 | 
				
			||||||
from cassandra.cqlengine.management import sync_table, drop_table
 | 
					from cassandra.cqlengine.management import sync_table, drop_table
 | 
				
			||||||
from cassandra.cqlengine.models import Model, ValidationError
 | 
					from cassandra.cqlengine.models import Model, ValidationError
 | 
				
			||||||
 | 
					from cassandra.cqlengine.usertype import UserType
 | 
				
			||||||
from cassandra import util
 | 
					from cassandra import util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from tests.integration import PROTOCOL_VERSION
 | 
					from tests.integration import PROTOCOL_VERSION
 | 
				
			||||||
from tests.integration.cqlengine.base import BaseCassEngTestCase
 | 
					from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestDatetime(BaseCassEngTestCase):
 | 
					class TestDatetime(BaseCassEngTestCase):
 | 
				
			||||||
@@ -62,7 +53,7 @@ class TestDatetime(BaseCassEngTestCase):
 | 
				
			|||||||
        now = datetime.now()
 | 
					        now = datetime.now()
 | 
				
			||||||
        self.DatetimeTest.objects.create(test_id=0, created_at=now)
 | 
					        self.DatetimeTest.objects.create(test_id=0, created_at=now)
 | 
				
			||||||
        dt2 = self.DatetimeTest.objects(test_id=0).first()
 | 
					        dt2 = self.DatetimeTest.objects(test_id=0).first()
 | 
				
			||||||
        assert dt2.created_at.timetuple()[:6] == now.timetuple()[:6]
 | 
					        self.assertEqual(dt2.created_at.timetuple()[:6], now.timetuple()[:6])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_datetime_tzinfo_io(self):
 | 
					    def test_datetime_tzinfo_io(self):
 | 
				
			||||||
        class TZ(tzinfo):
 | 
					        class TZ(tzinfo):
 | 
				
			||||||
@@ -74,21 +65,27 @@ class TestDatetime(BaseCassEngTestCase):
 | 
				
			|||||||
        now = datetime(1982, 1, 1, tzinfo=TZ())
 | 
					        now = datetime(1982, 1, 1, tzinfo=TZ())
 | 
				
			||||||
        dt = self.DatetimeTest.objects.create(test_id=1, created_at=now)
 | 
					        dt = self.DatetimeTest.objects.create(test_id=1, created_at=now)
 | 
				
			||||||
        dt2 = self.DatetimeTest.objects(test_id=1).first()
 | 
					        dt2 = self.DatetimeTest.objects(test_id=1).first()
 | 
				
			||||||
        assert dt2.created_at.timetuple()[:6] == (now + timedelta(hours=1)).timetuple()[:6]
 | 
					        self.assertEqual(dt2.created_at.timetuple()[:6], (now + timedelta(hours=1)).timetuple()[:6])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_datetime_date_support(self):
 | 
					    def test_datetime_date_support(self):
 | 
				
			||||||
        today = date.today()
 | 
					        today = date.today()
 | 
				
			||||||
        self.DatetimeTest.objects.create(test_id=2, created_at=today)
 | 
					        self.DatetimeTest.objects.create(test_id=2, created_at=today)
 | 
				
			||||||
        dt2 = self.DatetimeTest.objects(test_id=2).first()
 | 
					        dt2 = self.DatetimeTest.objects(test_id=2).first()
 | 
				
			||||||
        assert dt2.created_at.isoformat() == datetime(today.year, today.month, today.day).isoformat()
 | 
					        self.assertEqual(dt2.created_at.isoformat(), datetime(today.year, today.month, today.day).isoformat())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2).first()
 | 
				
			||||||
 | 
					        self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2, created_at=today).first()
 | 
				
			||||||
 | 
					        self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_datetime_none(self):
 | 
					    def test_datetime_none(self):
 | 
				
			||||||
        dt = self.DatetimeTest.objects.create(test_id=3, created_at=None)
 | 
					        dt = self.DatetimeTest.objects.create(test_id=3, created_at=None)
 | 
				
			||||||
        dt2 = self.DatetimeTest.objects(test_id=3).first()
 | 
					        dt2 = self.DatetimeTest.objects(test_id=3).first()
 | 
				
			||||||
        assert dt2.created_at is None
 | 
					        self.assertIsNone(dt2.created_at)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        dts = self.DatetimeTest.objects.filter(test_id=3).values_list('created_at')
 | 
					        dts = self.DatetimeTest.objects.filter(test_id=3).values_list('created_at')
 | 
				
			||||||
        assert dts[0][0] is None
 | 
					        self.assertIsNone(dts[0][0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_datetime_invalid(self):
 | 
					    def test_datetime_invalid(self):
 | 
				
			||||||
        dt_value= 'INVALID'
 | 
					        dt_value= 'INVALID'
 | 
				
			||||||
@@ -99,13 +96,13 @@ class TestDatetime(BaseCassEngTestCase):
 | 
				
			|||||||
        dt_value = 1454520554
 | 
					        dt_value = 1454520554
 | 
				
			||||||
        self.DatetimeTest.objects.create(test_id=5, created_at=dt_value)
 | 
					        self.DatetimeTest.objects.create(test_id=5, created_at=dt_value)
 | 
				
			||||||
        dt2 = self.DatetimeTest.objects(test_id=5).first()
 | 
					        dt2 = self.DatetimeTest.objects(test_id=5).first()
 | 
				
			||||||
        assert dt2.created_at == datetime.utcfromtimestamp(dt_value)
 | 
					        self.assertEqual(dt2.created_at, datetime.utcfromtimestamp(dt_value))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_datetime_large(self):
 | 
					    def test_datetime_large(self):
 | 
				
			||||||
        dt_value = datetime(2038, 12, 31, 10, 10, 10, 123000)
 | 
					        dt_value = datetime(2038, 12, 31, 10, 10, 10, 123000)
 | 
				
			||||||
        self.DatetimeTest.objects.create(test_id=6, created_at=dt_value)
 | 
					        self.DatetimeTest.objects.create(test_id=6, created_at=dt_value)
 | 
				
			||||||
        dt2 = self.DatetimeTest.objects(test_id=6).first()
 | 
					        dt2 = self.DatetimeTest.objects(test_id=6).first()
 | 
				
			||||||
        assert dt2.created_at == dt_value
 | 
					        self.assertEqual(dt2.created_at, dt_value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_datetime_truncate_microseconds(self):
 | 
					    def test_datetime_truncate_microseconds(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -187,49 +184,170 @@ class TestVarInt(BaseCassEngTestCase):
 | 
				
			|||||||
        int2 = self.VarIntTest.objects(test_id=0).first()
 | 
					        int2 = self.VarIntTest.objects(test_id=0).first()
 | 
				
			||||||
        self.assertEqual(int1.bignum, int2.bignum)
 | 
					        self.assertEqual(int1.bignum, int2.bignum)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with self.assertRaises(ValidationError):
 | 
				
			||||||
 | 
					            self.VarIntTest.objects.create(test_id=0, bignum="not_a_number")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestDate(BaseCassEngTestCase):
 | 
					 | 
				
			||||||
    class DateTest(Model):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        test_id = Integer(primary_key=True)
 | 
					 | 
				
			||||||
        created_at = Date()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DataType():
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def setUpClass(cls):
 | 
					    def setUpClass(cls):
 | 
				
			||||||
        if PROTOCOL_VERSION < 4:
 | 
					        if PROTOCOL_VERSION < 4:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        sync_table(cls.DateTest)
 | 
					
 | 
				
			||||||
 | 
					        class DataTypeTest(Model):
 | 
				
			||||||
 | 
					            test_id = Integer(primary_key=True)
 | 
				
			||||||
 | 
					            class_param = cls.db_klass()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cls.model_class = DataTypeTest
 | 
				
			||||||
 | 
					        sync_table(cls.model_class)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def tearDownClass(cls):
 | 
					    def tearDownClass(cls):
 | 
				
			||||||
        if PROTOCOL_VERSION < 4:
 | 
					        if PROTOCOL_VERSION < 4:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        drop_table(cls.DateTest)
 | 
					        drop_table(cls.model_class)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
        if PROTOCOL_VERSION < 4:
 | 
					        if PROTOCOL_VERSION < 4:
 | 
				
			||||||
            raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION))
 | 
					            raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_date_io(self):
 | 
					    def _check_value_is_correct_in_db(self, value):
 | 
				
			||||||
        today = date.today()
 | 
					
 | 
				
			||||||
        self.DateTest.objects.create(test_id=0, created_at=today)
 | 
					        if value is None:
 | 
				
			||||||
        result = self.DateTest.objects(test_id=0).first()
 | 
					            result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first()
 | 
				
			||||||
        self.assertEqual(result.created_at, util.Date(today))
 | 
					            self.assertIsNone(result.class_param)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            if not isinstance(value, self.python_klass):
 | 
				
			||||||
 | 
					                value_to_compare = self.python_klass(value)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                value_to_compare = value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            result = self.model_class.objects(test_id=0).first()
 | 
				
			||||||
 | 
					            self.assertIsInstance(result.class_param, self.python_klass)
 | 
				
			||||||
 | 
					            self.assertEqual(result.class_param, value_to_compare)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first()
 | 
				
			||||||
 | 
					            self.assertIsInstance(result.class_param, self.python_klass)
 | 
				
			||||||
 | 
					            self.assertEqual(result.class_param, value_to_compare)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            result = self.model_class.objects.all().allow_filtering().filter(test_id=0, class_param=value).first()
 | 
				
			||||||
 | 
					            self.assertIsInstance(result.class_param, self.python_klass)
 | 
				
			||||||
 | 
					            self.assertEqual(result.class_param, value_to_compare)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_date_io_using_datetime(self):
 | 
					    def test_date_io_using_datetime(self):
 | 
				
			||||||
        now = datetime.utcnow()
 | 
					        first_value = self.first_value
 | 
				
			||||||
        self.DateTest.objects.create(test_id=0, created_at=now)
 | 
					        second_value = self.second_value
 | 
				
			||||||
        result = self.DateTest.objects(test_id=0).first()
 | 
					        third_value = self.third_value
 | 
				
			||||||
        self.assertIsInstance(result.created_at, util.Date)
 | 
					
 | 
				
			||||||
        self.assertEqual(result.created_at, util.Date(now))
 | 
					        self.model_class.objects.create(test_id=0, class_param=first_value)
 | 
				
			||||||
 | 
					        result = self._check_value_is_correct_in_db(first_value)
 | 
				
			||||||
 | 
					        result.delete()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.model_class.objects.create(test_id=0, class_param=second_value)
 | 
				
			||||||
 | 
					        result = self._check_value_is_correct_in_db(second_value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        result.update(class_param=third_value).save()
 | 
				
			||||||
 | 
					        result = self._check_value_is_correct_in_db(third_value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        result.update(class_param=None).save()
 | 
				
			||||||
 | 
					        self._check_value_is_correct_in_db(None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_date_none(self):
 | 
					    def test_date_none(self):
 | 
				
			||||||
        self.DateTest.objects.create(test_id=1, created_at=None)
 | 
					        self.model_class.objects.create(test_id=1, class_param=None)
 | 
				
			||||||
        dt2 = self.DateTest.objects(test_id=1).first()
 | 
					        dt2 = self.model_class.objects(test_id=1).first()
 | 
				
			||||||
        assert dt2.created_at is None
 | 
					        self.assertIsNone(dt2.class_param)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        dts = self.DateTest.objects(test_id=1).values_list('created_at')
 | 
					        dts = self.model_class.objects(test_id=1).values_list('class_param')
 | 
				
			||||||
        assert dts[0][0] is None
 | 
					        self.assertIsNone(dts[0][0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestDate(DataType, BaseCassEngTestCase):
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def setUpClass(cls):
 | 
				
			||||||
 | 
					        cls.db_klass, cls.python_klass = Date, util.Date
 | 
				
			||||||
 | 
					        cls.first_value, cls.second_value, cls.third_value = \
 | 
				
			||||||
 | 
					            datetime.utcnow(), util.Date(datetime(1, 1, 1)), datetime(1, 1, 2)
 | 
				
			||||||
 | 
					        super(TestDate, cls).setUpClass()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestTime(DataType, BaseCassEngTestCase):
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def setUpClass(cls):
 | 
				
			||||||
 | 
					        cls.db_klass, cls.python_klass = Time, util.Time
 | 
				
			||||||
 | 
					        cls.first_value, cls.second_value, cls.third_value = \
 | 
				
			||||||
 | 
					            time(2, 12, 7, 48),  util.Time(time(2, 12, 7, 49)), time(2, 12, 7, 50)
 | 
				
			||||||
 | 
					        super(TestTime, cls).setUpClass()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestDateTime(DataType, BaseCassEngTestCase):
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def setUpClass(cls):
 | 
				
			||||||
 | 
					        cls.db_klass, cls.python_klass = DateTime, datetime
 | 
				
			||||||
 | 
					        cls.first_value, cls.second_value, cls.third_value = \
 | 
				
			||||||
 | 
					            datetime(2017, 4, 13, 18, 34, 24, 317000), datetime(1, 1, 1), datetime(1, 1, 2)
 | 
				
			||||||
 | 
					        super(TestDateTime, cls).setUpClass()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestBoolean(DataType, BaseCassEngTestCase):
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def setUpClass(cls):
 | 
				
			||||||
 | 
					        cls.db_klass, cls.python_klass = Boolean, bool
 | 
				
			||||||
 | 
					        cls.first_value, cls.second_value, cls.third_value = True, False, True
 | 
				
			||||||
 | 
					        super(TestBoolean, cls).setUpClass()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class User(UserType):
 | 
				
			||||||
 | 
					    # We use Date and Time to ensure to_python
 | 
				
			||||||
 | 
					    # is called for these columns
 | 
				
			||||||
 | 
					    age = Integer()
 | 
				
			||||||
 | 
					    date_param = Date()
 | 
				
			||||||
 | 
					    map_param = Map(Integer, Time)
 | 
				
			||||||
 | 
					    list_param = List(Date)
 | 
				
			||||||
 | 
					    set_param = Set(Date)
 | 
				
			||||||
 | 
					    tuple_param = Tuple(Date, Decimal, Boolean, VarInt, Double, UUID)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UserModel(Model):
 | 
				
			||||||
 | 
					    test_id = Integer(primary_key=True)
 | 
				
			||||||
 | 
					    class_param = UserDefinedType(User)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestUDT(DataType, BaseCassEngTestCase):
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def setUpClass(cls):
 | 
				
			||||||
 | 
					        cls.db_klass, cls.python_klass = UserDefinedType, User
 | 
				
			||||||
 | 
					        cls.first_value = User(
 | 
				
			||||||
 | 
					            age=1,
 | 
				
			||||||
 | 
					            date_param=datetime.utcnow(),
 | 
				
			||||||
 | 
					            map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))},
 | 
				
			||||||
 | 
					            list_param=[datetime(1, 1, 2), datetime(1, 1, 3)],
 | 
				
			||||||
 | 
					            set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 1)))),
 | 
				
			||||||
 | 
					            tuple_param=(datetime(1, 1, 3), 2, False, 1, 2.324, uuid4())
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cls.second_value = User(
 | 
				
			||||||
 | 
					            age=1,
 | 
				
			||||||
 | 
					            date_param=datetime.utcnow(),
 | 
				
			||||||
 | 
					            map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))},
 | 
				
			||||||
 | 
					            list_param=[datetime(1, 1, 2), datetime(1, 2, 3)],
 | 
				
			||||||
 | 
					            set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 1)))),
 | 
				
			||||||
 | 
					            tuple_param=(datetime(1, 1, 2), 2, False, 1, 2.324, uuid4())
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cls.third_value = User(
 | 
				
			||||||
 | 
					            age=2,
 | 
				
			||||||
 | 
					            date_param=datetime.utcnow(),
 | 
				
			||||||
 | 
					            map_param={1: time(2, 12, 7, 51), 2: util.Time(time(2, 12, 7, 49))},
 | 
				
			||||||
 | 
					            list_param=[datetime(1, 1, 2), datetime(1, 1, 4)],
 | 
				
			||||||
 | 
					            set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 2)))),
 | 
				
			||||||
 | 
					            tuple_param=(datetime(1, 1, 2), 3, False, 1, 2.3214, uuid4())
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cls.model_class = UserModel
 | 
				
			||||||
 | 
					        sync_table(cls.model_class)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestDecimal(BaseCassEngTestCase):
 | 
					class TestDecimal(BaseCassEngTestCase):
 | 
				
			||||||
@@ -319,7 +437,7 @@ class TestInteger(BaseCassEngTestCase):
 | 
				
			|||||||
    class IntegerTest(Model):
 | 
					    class IntegerTest(Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        test_id = UUID(primary_key=True, default=lambda:uuid4())
 | 
					        test_id = UUID(primary_key=True, default=lambda:uuid4())
 | 
				
			||||||
        value   = Integer(default=0, required=True)
 | 
					        value = Integer(default=0, required=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_default_zero_fields_validate(self):
 | 
					    def test_default_zero_fields_validate(self):
 | 
				
			||||||
        """ Tests that integer columns with a default value of 0 validate """
 | 
					        """ Tests that integer columns with a default value of 0 validate """
 | 
				
			||||||
@@ -644,4 +762,3 @@ class TestInet(BaseCassEngTestCase):
 | 
				
			|||||||
        # TODO: presently this only tests that the server blows it up. Is there supposed to be local validation?
 | 
					        # TODO: presently this only tests that the server blows it up. Is there supposed to be local validation?
 | 
				
			||||||
        with self.assertRaises(InvalidRequest):
 | 
					        with self.assertRaises(InvalidRequest):
 | 
				
			||||||
            self.InetTestModel.create(address="what is going on here?")
 | 
					            self.InetTestModel.create(address="what is going on here?")
 | 
				
			||||||
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -24,11 +24,45 @@ from uuid import UUID, uuid4
 | 
				
			|||||||
from cassandra.cqlengine.models import Model
 | 
					from cassandra.cqlengine.models import Model
 | 
				
			||||||
from cassandra.cqlengine.usertype import UserType, UserTypeDefinitionException
 | 
					from cassandra.cqlengine.usertype import UserType, UserTypeDefinitionException
 | 
				
			||||||
from cassandra.cqlengine import columns, connection
 | 
					from cassandra.cqlengine import columns, connection
 | 
				
			||||||
from cassandra.cqlengine.management import sync_table, sync_type, create_keyspace_simple, drop_keyspace, drop_table
 | 
					from cassandra.cqlengine.management import sync_table, drop_table, sync_type, create_keyspace_simple, drop_keyspace
 | 
				
			||||||
 | 
					from cassandra.cqlengine import ValidationError
 | 
				
			||||||
from cassandra.util import Date, Time
 | 
					from cassandra.util import Date, Time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from tests.integration import PROTOCOL_VERSION
 | 
					from tests.integration import PROTOCOL_VERSION
 | 
				
			||||||
from tests.integration.cqlengine.base import BaseCassEngTestCase
 | 
					from tests.integration.cqlengine.base import BaseCassEngTestCase
 | 
				
			||||||
 | 
					from tests.integration.cqlengine import DEFAULT_KEYSPACE
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class User(UserType):
 | 
				
			||||||
 | 
					    age = columns.Integer()
 | 
				
			||||||
 | 
					    name = columns.Text()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UserModel(Model):
 | 
				
			||||||
 | 
					    id = columns.Integer(primary_key=True)
 | 
				
			||||||
 | 
					    info = columns.UserDefinedType(User)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AllDatatypes(UserType):
 | 
				
			||||||
 | 
					    a = columns.Ascii()
 | 
				
			||||||
 | 
					    b = columns.BigInt()
 | 
				
			||||||
 | 
					    c = columns.Blob()
 | 
				
			||||||
 | 
					    d = columns.Boolean()
 | 
				
			||||||
 | 
					    e = columns.DateTime()
 | 
				
			||||||
 | 
					    f = columns.Decimal()
 | 
				
			||||||
 | 
					    g = columns.Double()
 | 
				
			||||||
 | 
					    h = columns.Float()
 | 
				
			||||||
 | 
					    i = columns.Inet()
 | 
				
			||||||
 | 
					    j = columns.Integer()
 | 
				
			||||||
 | 
					    k = columns.Text()
 | 
				
			||||||
 | 
					    l = columns.TimeUUID()
 | 
				
			||||||
 | 
					    m = columns.UUID()
 | 
				
			||||||
 | 
					    n = columns.VarInt()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AllDatatypesModel(Model):
 | 
				
			||||||
 | 
					    id = columns.Integer(primary_key=True)
 | 
				
			||||||
 | 
					    data = columns.UserDefinedType(AllDatatypes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
					class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			||||||
@@ -42,7 +76,7 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
            age = columns.Integer()
 | 
					            age = columns.Integer()
 | 
				
			||||||
            name = columns.Text()
 | 
					            name = columns.Text()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sync_type("cqlengine_test", User)
 | 
					        sync_type(DEFAULT_KEYSPACE, User)
 | 
				
			||||||
        user = User(age=42, name="John")
 | 
					        user = User(age=42, name="John")
 | 
				
			||||||
        self.assertEqual(42, user.age)
 | 
					        self.assertEqual(42, user.age)
 | 
				
			||||||
        self.assertEqual("John", user.name)
 | 
					        self.assertEqual("John", user.name)
 | 
				
			||||||
@@ -53,8 +87,10 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
            name = columns.Text()
 | 
					            name = columns.Text()
 | 
				
			||||||
            gender = columns.Text()
 | 
					            gender = columns.Text()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sync_type("cqlengine_test", User)
 | 
					        sync_type(DEFAULT_KEYSPACE, User)
 | 
				
			||||||
        user = User(age=42, name="John", gender="male")
 | 
					        user = User(age=42)
 | 
				
			||||||
 | 
					        user["name"] = "John"
 | 
				
			||||||
 | 
					        user["gender"] = "male"
 | 
				
			||||||
        self.assertEqual(42, user.age)
 | 
					        self.assertEqual(42, user.age)
 | 
				
			||||||
        self.assertEqual("John", user.name)
 | 
					        self.assertEqual("John", user.name)
 | 
				
			||||||
        self.assertEqual("male", user.gender)
 | 
					        self.assertEqual("male", user.gender)
 | 
				
			||||||
@@ -64,19 +100,12 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
            age = columns.Integer()
 | 
					            age = columns.Integer()
 | 
				
			||||||
            name = columns.Text()
 | 
					            name = columns.Text()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sync_type("cqlengine_test", User)
 | 
					        sync_type(DEFAULT_KEYSPACE, User)
 | 
				
			||||||
        user = User(age=42, name="John", gender="male")
 | 
					        user = User(age=42, name="John", gender="male")
 | 
				
			||||||
        with self.assertRaises(AttributeError):
 | 
					        with self.assertRaises(AttributeError):
 | 
				
			||||||
            user.gender
 | 
					            user.gender
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_can_insert_udts(self):
 | 
					    def test_can_insert_udts(self):
 | 
				
			||||||
        class User(UserType):
 | 
					 | 
				
			||||||
            age = columns.Integer()
 | 
					 | 
				
			||||||
            name = columns.Text()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        class UserModel(Model):
 | 
					 | 
				
			||||||
            id = columns.Integer(primary_key=True)
 | 
					 | 
				
			||||||
            info = columns.UserDefinedType(User)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sync_table(UserModel)
 | 
					        sync_table(UserModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -90,16 +119,9 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        self.assertTrue(type(john.info) is User)
 | 
					        self.assertTrue(type(john.info) is User)
 | 
				
			||||||
        self.assertEqual(42, john.info.age)
 | 
					        self.assertEqual(42, john.info.age)
 | 
				
			||||||
        self.assertEqual("John", john.info.name)
 | 
					        self.assertEqual("John", john.info.name)
 | 
				
			||||||
 | 
					        drop_table(UserModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_can_update_udts(self):
 | 
					    def test_can_update_udts(self):
 | 
				
			||||||
        class User(UserType):
 | 
					 | 
				
			||||||
            age = columns.Integer()
 | 
					 | 
				
			||||||
            name = columns.Text()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        class UserModel(Model):
 | 
					 | 
				
			||||||
            id = columns.Integer(primary_key=True)
 | 
					 | 
				
			||||||
            info = columns.UserDefinedType(User)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        sync_table(UserModel)
 | 
					        sync_table(UserModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        user = User(age=42, name="John")
 | 
					        user = User(age=42, name="John")
 | 
				
			||||||
@@ -113,18 +135,11 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        created_user.update()
 | 
					        created_user.update()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        mary_info = UserModel.objects().first().info
 | 
					        mary_info = UserModel.objects().first().info
 | 
				
			||||||
        self.assertEqual(22, mary_info.age)
 | 
					        self.assertEqual(22, mary_info["age"])
 | 
				
			||||||
        self.assertEqual("Mary", mary_info.name)
 | 
					        self.assertEqual("Mary", mary_info["name"])
 | 
				
			||||||
 | 
					        drop_table(UserModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_can_update_udts_with_nones(self):
 | 
					    def test_can_update_udts_with_nones(self):
 | 
				
			||||||
        class User(UserType):
 | 
					 | 
				
			||||||
            age = columns.Integer()
 | 
					 | 
				
			||||||
            name = columns.Text()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        class UserModel(Model):
 | 
					 | 
				
			||||||
            id = columns.Integer(primary_key=True)
 | 
					 | 
				
			||||||
            info = columns.UserDefinedType(User)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        sync_table(UserModel)
 | 
					        sync_table(UserModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        user = User(age=42, name="John")
 | 
					        user = User(age=42, name="John")
 | 
				
			||||||
@@ -139,45 +154,43 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        john_info = UserModel.objects().first().info
 | 
					        john_info = UserModel.objects().first().info
 | 
				
			||||||
        self.assertIsNone(john_info)
 | 
					        self.assertIsNone(john_info)
 | 
				
			||||||
 | 
					        drop_table(UserModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_can_create_same_udt_different_keyspaces(self):
 | 
					    def test_can_create_same_udt_different_keyspaces(self):
 | 
				
			||||||
        class User(UserType):
 | 
					        sync_type(DEFAULT_KEYSPACE, User)
 | 
				
			||||||
            age = columns.Integer()
 | 
					 | 
				
			||||||
            name = columns.Text()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        sync_type("cqlengine_test", User)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        create_keyspace_simple("simplex", 1)
 | 
					        create_keyspace_simple("simplex", 1)
 | 
				
			||||||
        sync_type("simplex", User)
 | 
					        sync_type("simplex", User)
 | 
				
			||||||
        drop_keyspace("simplex")
 | 
					        drop_keyspace("simplex")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_can_insert_partial_udts(self):
 | 
					    def test_can_insert_partial_udts(self):
 | 
				
			||||||
        class User(UserType):
 | 
					        class UserGender(UserType):
 | 
				
			||||||
            age = columns.Integer()
 | 
					            age = columns.Integer()
 | 
				
			||||||
            name = columns.Text()
 | 
					            name = columns.Text()
 | 
				
			||||||
            gender = columns.Text()
 | 
					            gender = columns.Text()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        class UserModel(Model):
 | 
					        class UserModelGender(Model):
 | 
				
			||||||
            id = columns.Integer(primary_key=True)
 | 
					            id = columns.Integer(primary_key=True)
 | 
				
			||||||
            info = columns.UserDefinedType(User)
 | 
					            info = columns.UserDefinedType(UserGender)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sync_table(UserModel)
 | 
					        sync_table(UserModelGender)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        user = User(age=42, name="John")
 | 
					        user = UserGender(age=42, name="John")
 | 
				
			||||||
        UserModel.create(id=0, info=user)
 | 
					        UserModelGender.create(id=0, info=user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        john_info = UserModel.objects().first().info
 | 
					        john_info = UserModelGender.objects().first().info
 | 
				
			||||||
        self.assertEqual(42, john_info.age)
 | 
					        self.assertEqual(42, john_info.age)
 | 
				
			||||||
        self.assertEqual("John", john_info.name)
 | 
					        self.assertEqual("John", john_info.name)
 | 
				
			||||||
        self.assertIsNone(john_info.gender)
 | 
					        self.assertIsNone(john_info.gender)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        user = User(age=42)
 | 
					        user = UserGender(age=42)
 | 
				
			||||||
        UserModel.create(id=0, info=user)
 | 
					        UserModelGender.create(id=0, info=user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        john_info = UserModel.objects().first().info
 | 
					        john_info = UserModelGender.objects().first().info
 | 
				
			||||||
        self.assertEqual(42, john_info.age)
 | 
					        self.assertEqual(42, john_info.age)
 | 
				
			||||||
        self.assertIsNone(john_info.name)
 | 
					        self.assertIsNone(john_info.name)
 | 
				
			||||||
        self.assertIsNone(john_info.gender)
 | 
					        self.assertIsNone(john_info.gender)
 | 
				
			||||||
 | 
					        drop_table(UserModelGender)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_can_insert_nested_udts(self):
 | 
					    def test_can_insert_nested_udts(self):
 | 
				
			||||||
        class Depth_0(UserType):
 | 
					        class Depth_0(UserType):
 | 
				
			||||||
@@ -215,6 +228,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        self.assertEqual(udts[2], output.v_2)
 | 
					        self.assertEqual(udts[2], output.v_2)
 | 
				
			||||||
        self.assertEqual(udts[3], output.v_3)
 | 
					        self.assertEqual(udts[3], output.v_3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        drop_table(DepthModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_can_insert_udts_with_nones(self):
 | 
					    def test_can_insert_udts_with_nones(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Test for inserting all column types as empty into a UserType as None's
 | 
					        Test for inserting all column types as empty into a UserType as None's
 | 
				
			||||||
@@ -230,27 +245,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        @test_category data_types:udt
 | 
					        @test_category data_types:udt
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					 | 
				
			||||||
        class AllDatatypes(UserType):
 | 
					 | 
				
			||||||
            a = columns.Ascii()
 | 
					 | 
				
			||||||
            b = columns.BigInt()
 | 
					 | 
				
			||||||
            c = columns.Blob()
 | 
					 | 
				
			||||||
            d = columns.Boolean()
 | 
					 | 
				
			||||||
            e = columns.DateTime()
 | 
					 | 
				
			||||||
            f = columns.Decimal()
 | 
					 | 
				
			||||||
            g = columns.Double()
 | 
					 | 
				
			||||||
            h = columns.Float()
 | 
					 | 
				
			||||||
            i = columns.Inet()
 | 
					 | 
				
			||||||
            j = columns.Integer()
 | 
					 | 
				
			||||||
            k = columns.Text()
 | 
					 | 
				
			||||||
            l = columns.TimeUUID()
 | 
					 | 
				
			||||||
            m = columns.UUID()
 | 
					 | 
				
			||||||
            n = columns.VarInt()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        class AllDatatypesModel(Model):
 | 
					 | 
				
			||||||
            id = columns.Integer(primary_key=True)
 | 
					 | 
				
			||||||
            data = columns.UserDefinedType(AllDatatypes)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        sync_table(AllDatatypesModel)
 | 
					        sync_table(AllDatatypesModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        input = AllDatatypes(a=None, b=None, c=None, d=None, e=None, f=None, g=None, h=None, i=None, j=None, k=None,
 | 
					        input = AllDatatypes(a=None, b=None, c=None, d=None, e=None, f=None, g=None, h=None, i=None, j=None, k=None,
 | 
				
			||||||
@@ -262,6 +256,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        output = AllDatatypesModel.objects().first().data
 | 
					        output = AllDatatypesModel.objects().first().data
 | 
				
			||||||
        self.assertEqual(input, output)
 | 
					        self.assertEqual(input, output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        drop_table(AllDatatypesModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_can_insert_udts_with_all_datatypes(self):
 | 
					    def test_can_insert_udts_with_all_datatypes(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Test for inserting all column types into a UserType
 | 
					        Test for inserting all column types into a UserType
 | 
				
			||||||
@@ -277,27 +273,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        @test_category data_types:udt
 | 
					        @test_category data_types:udt
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					 | 
				
			||||||
        class AllDatatypes(UserType):
 | 
					 | 
				
			||||||
            a = columns.Ascii()
 | 
					 | 
				
			||||||
            b = columns.BigInt()
 | 
					 | 
				
			||||||
            c = columns.Blob()
 | 
					 | 
				
			||||||
            d = columns.Boolean()
 | 
					 | 
				
			||||||
            e = columns.DateTime()
 | 
					 | 
				
			||||||
            f = columns.Decimal()
 | 
					 | 
				
			||||||
            g = columns.Double()
 | 
					 | 
				
			||||||
            h = columns.Float()
 | 
					 | 
				
			||||||
            i = columns.Inet()
 | 
					 | 
				
			||||||
            j = columns.Integer()
 | 
					 | 
				
			||||||
            k = columns.Text()
 | 
					 | 
				
			||||||
            l = columns.TimeUUID()
 | 
					 | 
				
			||||||
            m = columns.UUID()
 | 
					 | 
				
			||||||
            n = columns.VarInt()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        class AllDatatypesModel(Model):
 | 
					 | 
				
			||||||
            id = columns.Integer(primary_key=True)
 | 
					 | 
				
			||||||
            data = columns.UserDefinedType(AllDatatypes)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        sync_table(AllDatatypesModel)
 | 
					        sync_table(AllDatatypesModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        input = AllDatatypes(a='ascii', b=2 ** 63 - 1, c=bytearray(b'hello world'), d=True,
 | 
					        input = AllDatatypes(a='ascii', b=2 ** 63 - 1, c=bytearray(b'hello world'), d=True,
 | 
				
			||||||
@@ -313,6 +288,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        for i in range(ord('a'), ord('a') + 14):
 | 
					        for i in range(ord('a'), ord('a') + 14):
 | 
				
			||||||
            self.assertEqual(input[chr(i)], output[chr(i)])
 | 
					            self.assertEqual(input[chr(i)], output[chr(i)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        drop_table(AllDatatypesModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_can_insert_udts_protocol_v4_datatypes(self):
 | 
					    def test_can_insert_udts_protocol_v4_datatypes(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Test for inserting all protocol v4 column types into a UserType
 | 
					        Test for inserting all protocol v4 column types into a UserType
 | 
				
			||||||
@@ -354,6 +331,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        for i in range(ord('a'), ord('a') + 3):
 | 
					        for i in range(ord('a'), ord('a') + 3):
 | 
				
			||||||
            self.assertEqual(input[chr(i)], output[chr(i)])
 | 
					            self.assertEqual(input[chr(i)], output[chr(i)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        drop_table(Allv4DatatypesModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_nested_udts_inserts(self):
 | 
					    def test_nested_udts_inserts(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Test for inserting collections of user types using cql engine.
 | 
					        Test for inserting collections of user types using cql engine.
 | 
				
			||||||
@@ -394,6 +373,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        names_output = Container.objects().first().names
 | 
					        names_output = Container.objects().first().names
 | 
				
			||||||
        self.assertEqual(names_output, names)
 | 
					        self.assertEqual(names_output, names)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        drop_table(Container)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_udts_with_unicode(self):
 | 
					    def test_udts_with_unicode(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Test for inserting models with unicode and udt columns.
 | 
					        Test for inserting models with unicode and udt columns.
 | 
				
			||||||
@@ -410,10 +391,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        ascii_name = 'normal name'
 | 
					        ascii_name = 'normal name'
 | 
				
			||||||
        unicode_name = u'Fran\u00E7ois'
 | 
					        unicode_name = u'Fran\u00E7ois'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        class User(UserType):
 | 
					 | 
				
			||||||
            age = columns.Integer()
 | 
					 | 
				
			||||||
            name = columns.Text()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        class UserModelText(Model):
 | 
					        class UserModelText(Model):
 | 
				
			||||||
            id = columns.Text(primary_key=True)
 | 
					            id = columns.Text(primary_key=True)
 | 
				
			||||||
            info = columns.UserDefinedType(User)
 | 
					            info = columns.UserDefinedType(User)
 | 
				
			||||||
@@ -427,10 +404,9 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        UserModelText.create(id=unicode_name, info=user_template_ascii)
 | 
					        UserModelText.create(id=unicode_name, info=user_template_ascii)
 | 
				
			||||||
        UserModelText.create(id=unicode_name, info=user_template_unicode)
 | 
					        UserModelText.create(id=unicode_name, info=user_template_unicode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        drop_table(UserModelText)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_register_default_keyspace(self):
 | 
					    def test_register_default_keyspace(self):
 | 
				
			||||||
        class User(UserType):
 | 
					 | 
				
			||||||
            age = columns.Integer()
 | 
					 | 
				
			||||||
            name = columns.Text()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        from cassandra.cqlengine import models
 | 
					        from cassandra.cqlengine import models
 | 
				
			||||||
        from cassandra.cqlengine import connection
 | 
					        from cassandra.cqlengine import connection
 | 
				
			||||||
@@ -495,6 +471,8 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        self.assertEqual(info.a, age)
 | 
					        self.assertEqual(info.a, age)
 | 
				
			||||||
        self.assertEqual(info.n, name)
 | 
					        self.assertEqual(info.n, name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        drop_table(TheModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_db_field_overload(self):
 | 
					    def test_db_field_overload(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Tests for db_field UserTypeDefinitionException
 | 
					        Tests for db_field UserTypeDefinitionException
 | 
				
			||||||
@@ -520,9 +498,6 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def test_set_udt_fields(self):
 | 
					    def test_set_udt_fields(self):
 | 
				
			||||||
        # PYTHON-502
 | 
					        # PYTHON-502
 | 
				
			||||||
        class User(UserType):
 | 
					 | 
				
			||||||
            age = columns.Integer()
 | 
					 | 
				
			||||||
            name = columns.Text()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        u = User()
 | 
					        u = User()
 | 
				
			||||||
        u.age = 20
 | 
					        u.age = 20
 | 
				
			||||||
@@ -562,3 +537,63 @@ class UserDefinedTypeTests(BaseCassEngTestCase):
 | 
				
			|||||||
        self.assertEqual(t.nested[0].default_text, "default text")
 | 
					        self.assertEqual(t.nested[0].default_text, "default text")
 | 
				
			||||||
        self.assertIsNotNone(t.simple.test_id)
 | 
					        self.assertIsNotNone(t.simple.test_id)
 | 
				
			||||||
        self.assertEqual(t.simple.default_text, "default text")
 | 
					        self.assertEqual(t.simple.default_text, "default text")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        drop_table(OuterModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_udt_validate(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Test to verify restrictions are honored and that validate is called
 | 
				
			||||||
 | 
					        for each member of the UDT when an updated is attempted
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        @since 3.10
 | 
				
			||||||
 | 
					        @jira_ticket PYTHON-505
 | 
				
			||||||
 | 
					        @expected_result a validation error is arisen due to the name being
 | 
				
			||||||
 | 
					        too long
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        @test_category data_types:object_mapper
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        class UserValidate(UserType):
 | 
				
			||||||
 | 
					            age = columns.Integer()
 | 
				
			||||||
 | 
					            name = columns.Text(max_length=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class UserModelValidate(Model):
 | 
				
			||||||
 | 
					            id = columns.Integer(primary_key=True)
 | 
				
			||||||
 | 
					            info = columns.UserDefinedType(UserValidate)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sync_table(UserModelValidate)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        user = UserValidate(age=1, name="Robert")
 | 
				
			||||||
 | 
					        item = UserModelValidate(id=1, info=user)
 | 
				
			||||||
 | 
					        with self.assertRaises(ValidationError):
 | 
				
			||||||
 | 
					            item.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        drop_table(UserModelValidate)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_udt_validate_with_default(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Test to verify restrictions are honored and that validate is called
 | 
				
			||||||
 | 
					        on the default value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        @since 3.10
 | 
				
			||||||
 | 
					        @jira_ticket PYTHON-505
 | 
				
			||||||
 | 
					        @expected_result a validation error is arisen due to the name being
 | 
				
			||||||
 | 
					        too long
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        @test_category data_types:object_mapper
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        class UserValidateDefault(UserType):
 | 
				
			||||||
 | 
					            age = columns.Integer()
 | 
				
			||||||
 | 
					            name = columns.Text(max_length=2, default="Robert")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class UserModelValidateDefault(Model):
 | 
				
			||||||
 | 
					            id = columns.Integer(primary_key=True)
 | 
				
			||||||
 | 
					            info = columns.UserDefinedType(UserValidateDefault)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sync_table(UserModelValidateDefault)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        user = UserValidateDefault(age=1)
 | 
				
			||||||
 | 
					        item = UserModelValidateDefault(id=1, info=user)
 | 
				
			||||||
 | 
					        with self.assertRaises(ValidationError):
 | 
				
			||||||
 | 
					            item.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        drop_table(UserModelValidateDefault)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,7 +21,7 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase
 | 
				
			|||||||
from cassandra.cqlengine.models import Model
 | 
					from cassandra.cqlengine.models import Model
 | 
				
			||||||
from cassandra.cqlengine import columns
 | 
					from cassandra.cqlengine import columns
 | 
				
			||||||
from cassandra.cqlengine.management import sync_table, drop_table
 | 
					from cassandra.cqlengine.management import sync_table, drop_table
 | 
				
			||||||
 | 
					from cassandra.cqlengine.usertype import UserType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestUpdateModel(Model):
 | 
					class TestUpdateModel(Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -147,19 +147,40 @@ class ModelUpdateTests(BaseCassEngTestCase):
 | 
				
			|||||||
            m0.update(partition=uuid4())
 | 
					            m0.update(partition=uuid4())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UDT(UserType):
 | 
				
			||||||
 | 
					    age = columns.Integer()
 | 
				
			||||||
 | 
					    mf = columns.Map(columns.Integer, columns.Integer)
 | 
				
			||||||
 | 
					    dummy_udt = columns.Integer(default=42)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ModelWithDefault(Model):
 | 
					class ModelWithDefault(Model):
 | 
				
			||||||
    id          = columns.Integer(primary_key=True)
 | 
					    id = columns.Integer(primary_key=True)
 | 
				
			||||||
    mf          = columns.Map(columns.Integer, columns.Integer)
 | 
					    mf = columns.Map(columns.Integer, columns.Integer)
 | 
				
			||||||
    dummy       = columns.Integer(default=42)
 | 
					    dummy = columns.Integer(default=42)
 | 
				
			||||||
 | 
					    udt = columns.UserDefinedType(UDT)
 | 
				
			||||||
 | 
					    udt_default = columns.UserDefinedType(UDT, default=UDT(age=1, mf={2:2}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UDTWithDefault(UserType):
 | 
				
			||||||
 | 
					    age = columns.Integer()
 | 
				
			||||||
 | 
					    mf = columns.Map(columns.Integer, columns.Integer, default={2:2})
 | 
				
			||||||
 | 
					    dummy_udt = columns.Integer(default=42)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ModelWithDefaultCollection(Model):
 | 
					class ModelWithDefaultCollection(Model):
 | 
				
			||||||
    id          = columns.Integer(primary_key=True)
 | 
					    id = columns.Integer(primary_key=True)
 | 
				
			||||||
    mf          = columns.Map(columns.Integer, columns.Integer, default={2:2})
 | 
					    mf = columns.Map(columns.Integer, columns.Integer, default={2:2})
 | 
				
			||||||
    dummy       = columns.Integer(default=42)
 | 
					    dummy = columns.Integer(default=42)
 | 
				
			||||||
 | 
					    udt = columns.UserDefinedType(UDT)
 | 
				
			||||||
 | 
					    udt_default = columns.UserDefinedType(UDT, default=UDT(age=1, mf={2: 2}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ModelWithDefaultTests(BaseCassEngTestCase):
 | 
					class ModelWithDefaultTests(BaseCassEngTestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def setUpClass(cls):
 | 
				
			||||||
 | 
					        cls.default_udt = UDT(age=1, mf={2:2}, dummy_udt=42)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
        sync_table(ModelWithDefault)
 | 
					        sync_table(ModelWithDefault)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -176,17 +197,19 @@ class ModelWithDefaultTests(BaseCassEngTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        @test_category object_mapper
 | 
					        @test_category object_mapper
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        initial = ModelWithDefault(id=1, mf={0: 0}, dummy=0)
 | 
					        first_udt = UDT(age=1, mf={2:2}, dummy_udt=0)
 | 
				
			||||||
 | 
					        initial = ModelWithDefault(id=1, mf={0: 0}, dummy=0, udt=first_udt, udt_default=first_udt)
 | 
				
			||||||
        initial.save()
 | 
					        initial.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
					        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
				
			||||||
                         {'id': 1, 'dummy': 0, 'mf': {0: 0}})
 | 
					                         {'id': 1, 'dummy': 0, 'mf': {0: 0}, "udt": first_udt, "udt_default": first_udt})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        second_udt = UDT(age=1, mf={3: 3}, dummy_udt=12)
 | 
				
			||||||
        second = ModelWithDefault(id=1)
 | 
					        second = ModelWithDefault(id=1)
 | 
				
			||||||
        second.update(mf={0: 1})
 | 
					        second.update(mf={0: 1}, udt=second_udt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
					        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
				
			||||||
                         {'id': 1, 'dummy': 0, 'mf': {0: 1}})
 | 
					                         {'id': 1, 'dummy': 0, 'mf': {0: 1}, "udt": second_udt, "udt_default": first_udt})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_value_is_written_if_is_default(self):
 | 
					    def test_value_is_written_if_is_default(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -202,10 +225,11 @@ class ModelWithDefaultTests(BaseCassEngTestCase):
 | 
				
			|||||||
        initial = ModelWithDefault(id=1)
 | 
					        initial = ModelWithDefault(id=1)
 | 
				
			||||||
        initial.mf = {0: 0}
 | 
					        initial.mf = {0: 0}
 | 
				
			||||||
        initial.dummy = 42
 | 
					        initial.dummy = 42
 | 
				
			||||||
 | 
					        initial.udt_default = self.default_udt
 | 
				
			||||||
        initial.update()
 | 
					        initial.update()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
					        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
				
			||||||
                         {'id': 1, 'dummy': 42, 'mf': {0: 0}})
 | 
					                         {'id': 1, 'dummy': 42, 'mf': {0: 0}, "udt": None, "udt_default": self.default_udt})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_null_update_is_respected(self):
 | 
					    def test_null_update_is_respected(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -223,10 +247,11 @@ class ModelWithDefaultTests(BaseCassEngTestCase):
 | 
				
			|||||||
        q = ModelWithDefault.objects.all().allow_filtering()
 | 
					        q = ModelWithDefault.objects.all().allow_filtering()
 | 
				
			||||||
        obj = q.filter(id=1).get()
 | 
					        obj = q.filter(id=1).get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        obj.update(dummy=None)
 | 
					        updated_udt = UDT(age=1, mf={2:2}, dummy_udt=None)
 | 
				
			||||||
 | 
					        obj.update(dummy=None, udt_default=updated_udt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
					        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
				
			||||||
                         {'id': 1, 'dummy': None, 'mf': {0: 0}})
 | 
					                         {'id': 1, 'dummy': None, 'mf': {0: 0}, "udt": None, "udt_default": updated_udt})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_only_set_values_is_updated(self):
 | 
					    def test_only_set_values_is_updated(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -244,28 +269,32 @@ class ModelWithDefaultTests(BaseCassEngTestCase):
 | 
				
			|||||||
        item = ModelWithDefault.filter(id=1).first()
 | 
					        item = ModelWithDefault.filter(id=1).first()
 | 
				
			||||||
        ModelWithDefault.objects(id=1).delete()
 | 
					        ModelWithDefault.objects(id=1).delete()
 | 
				
			||||||
        item.mf = {1: 2}
 | 
					        item.mf = {1: 2}
 | 
				
			||||||
 | 
					        udt, default_udt = UDT(age=1, mf={2:3}), UDT(age=1, mf={2:3})
 | 
				
			||||||
 | 
					        item.udt, item.default_udt = udt, default_udt
 | 
				
			||||||
        item.save()
 | 
					        item.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
					        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
				
			||||||
                         {'id': 1, 'dummy': None, 'mf': {1: 2}})
 | 
					                         {'id': 1, 'dummy': None, 'mf': {1: 2}, "udt": udt, "udt_default": default_udt})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_collections(self):
 | 
					    def test_collections(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Test the updates work as expected when an object is deleted
 | 
					        Test the updates work as expected on Map objects
 | 
				
			||||||
        @since 3.9
 | 
					        @since 3.9
 | 
				
			||||||
        @jira_ticket PYTHON-657
 | 
					        @jira_ticket PYTHON-657
 | 
				
			||||||
        @expected_result the non updated column is None and the
 | 
					        @expected_result the row is updated when the Map object is
 | 
				
			||||||
        updated column has the set value
 | 
					        reduced
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @test_category object_mapper
 | 
					        @test_category object_mapper
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        ModelWithDefault.create(id=1, mf={1: 1, 2: 1}, dummy=1).save()
 | 
					        udt, udt_default = UDT(age=1, mf={1: 1, 2: 1}), UDT(age=1, mf={1: 1, 2: 1})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ModelWithDefault.create(id=1, mf={1: 1, 2: 1}, dummy=1, udt=udt, udt_default=udt_default).save()
 | 
				
			||||||
        item = ModelWithDefault.filter(id=1).first()
 | 
					        item = ModelWithDefault.filter(id=1).first()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        udt, udt_default = UDT(age=1, mf={2: 1}), UDT(age=1, mf={2: 1})
 | 
				
			||||||
        item.update(mf={2:1})
 | 
					        item.update(mf={2:1})
 | 
				
			||||||
        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
					        self.assertEqual(ModelWithDefault.objects().all().get()._as_dict(),
 | 
				
			||||||
                         {'id': 1, 'dummy': 1, 'mf': {2: 1}})
 | 
					                         {'id': 1, 'dummy': 1, 'mf': {2: 1}, "udt": udt, "udt_default": udt_default})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_collection_with_default(self):
 | 
					    def test_collection_with_default(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -278,24 +307,32 @@ class ModelWithDefaultTests(BaseCassEngTestCase):
 | 
				
			|||||||
        @test_category object_mapper
 | 
					        @test_category object_mapper
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        sync_table(ModelWithDefaultCollection)
 | 
					        sync_table(ModelWithDefaultCollection)
 | 
				
			||||||
        item = ModelWithDefaultCollection.create(id=1, mf={1: 1}, dummy=1).save()
 | 
					 | 
				
			||||||
        self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
 | 
					 | 
				
			||||||
                         {'id': 1, 'dummy': 1, 'mf': {1: 1}})
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        item.update(mf={2: 2})
 | 
					        udt, udt_default = UDT(age=1, mf={6: 6}), UDT(age=1, mf={6: 6})
 | 
				
			||||||
        self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
 | 
					 | 
				
			||||||
                         {'id': 1, 'dummy': 1, 'mf': {2: 2}})
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        item.update(mf=None)
 | 
					        item = ModelWithDefaultCollection.create(id=1, mf={1: 1}, dummy=1, udt=udt, udt_default=udt_default).save()
 | 
				
			||||||
        self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
 | 
					        self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
 | 
				
			||||||
                         {'id': 1, 'dummy': 1, 'mf': {}})
 | 
					                         {'id': 1, 'dummy': 1, 'mf': {1: 1}, "udt": udt, "udt_default": udt_default})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        udt, udt_default = UDT(age=1, mf={5: 5}), UDT(age=1, mf={5: 5})
 | 
				
			||||||
 | 
					        item.update(mf={2: 2}, udt=udt, udt_default=udt_default)
 | 
				
			||||||
 | 
					        self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
 | 
				
			||||||
 | 
					                         {'id': 1, 'dummy': 1, 'mf': {2: 2}, "udt": udt, "udt_default": udt_default})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        udt, udt_default = UDT(age=1, mf=None), UDT(age=1, mf=None)
 | 
				
			||||||
 | 
					        expected_udt, expected_udt_default = UDT(age=1, mf={}), UDT(age=1, mf={})
 | 
				
			||||||
 | 
					        item.update(mf=None, udt=udt, udt_default=udt_default)
 | 
				
			||||||
 | 
					        self.assertEqual(ModelWithDefaultCollection.objects().all().get()._as_dict(),
 | 
				
			||||||
 | 
					                         {'id': 1, 'dummy': 1, 'mf': {}, "udt": expected_udt, "udt_default": expected_udt_default})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        udt_default = UDT(age=1, mf=None), UDT(age=1, mf={5:5})
 | 
				
			||||||
        item = ModelWithDefaultCollection.create(id=2, dummy=2).save()
 | 
					        item = ModelWithDefaultCollection.create(id=2, dummy=2).save()
 | 
				
			||||||
        self.assertEqual(ModelWithDefaultCollection.objects().all().get(id=2)._as_dict(),
 | 
					        self.assertEqual(ModelWithDefaultCollection.objects().all().get(id=2)._as_dict(),
 | 
				
			||||||
                         {'id': 2, 'dummy': 2, 'mf': {2: 2}})
 | 
					                         {'id': 2, 'dummy': 2, 'mf': {2: 2}, "udt": None, "udt_default": udt_default})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        item.update(mf={1: 1, 4: 4})
 | 
					        udt, udt_default = UDT(age=1, mf={1: 1, 6: 6}), UDT(age=1, mf={1: 1, 6: 6})
 | 
				
			||||||
 | 
					        item.update(mf={1: 1, 4: 4}, udt=udt, udt_default=udt_default)
 | 
				
			||||||
        self.assertEqual(ModelWithDefaultCollection.objects().all().get(id=2)._as_dict(),
 | 
					        self.assertEqual(ModelWithDefaultCollection.objects().all().get(id=2)._as_dict(),
 | 
				
			||||||
                         {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}})
 | 
					                         {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        drop_table(ModelWithDefaultCollection)
 | 
					        drop_table(ModelWithDefaultCollection)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -150,7 +150,7 @@ class TestTokenFunction(BaseCassEngTestCase):
 | 
				
			|||||||
            TokenTestModel.create(key=i, val=i)
 | 
					            TokenTestModel.create(key=i, val=i)
 | 
				
			||||||
        named = NamedTable(DEFAULT_KEYSPACE, TokenTestModel.__table_name__)
 | 
					        named = NamedTable(DEFAULT_KEYSPACE, TokenTestModel.__table_name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        query = named.objects.all().limit(1)
 | 
					        query = named.all().limit(1)
 | 
				
			||||||
        first_page = list(query)
 | 
					        first_page = list(query)
 | 
				
			||||||
        last = first_page[-1]
 | 
					        last = first_page[-1]
 | 
				
			||||||
        self.assertTrue(len(first_page) is 1)
 | 
					        self.assertTrue(len(first_page) is 1)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,21 +18,12 @@ from cassandra.cqlengine import ValidationError
 | 
				
			|||||||
from cassandra.cqlengine.models import Model
 | 
					from cassandra.cqlengine.models import Model
 | 
				
			||||||
from cassandra.cqlengine.management import sync_table, drop_table
 | 
					from cassandra.cqlengine.management import sync_table, drop_table
 | 
				
			||||||
from cassandra.cqlengine import columns
 | 
					from cassandra.cqlengine import columns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from tests.integration.cqlengine import is_prepend_reversed
 | 
					from tests.integration.cqlengine import is_prepend_reversed
 | 
				
			||||||
from tests.integration.cqlengine.base import BaseCassEngTestCase
 | 
					from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel
 | 
				
			||||||
from tests.integration.cqlengine import execute_count
 | 
					from tests.integration.cqlengine import execute_count
 | 
				
			||||||
from tests.integration import greaterthancass20
 | 
					from tests.integration import greaterthancass20
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestQueryUpdateModel(Model):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    partition = columns.UUID(primary_key=True, default=uuid4)
 | 
					 | 
				
			||||||
    cluster = columns.Integer(primary_key=True)
 | 
					 | 
				
			||||||
    count = columns.Integer(required=False)
 | 
					 | 
				
			||||||
    text = columns.Text(required=False, index=True)
 | 
					 | 
				
			||||||
    text_set = columns.Set(columns.Text, required=False)
 | 
					 | 
				
			||||||
    text_list = columns.List(columns.Text, required=False)
 | 
					 | 
				
			||||||
    text_map = columns.Map(columns.Text, columns.Text, required=False)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class QueryUpdateTests(BaseCassEngTestCase):
 | 
					class QueryUpdateTests(BaseCassEngTestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -16,8 +16,19 @@ try:
 | 
				
			|||||||
except ImportError:
 | 
					except ImportError:
 | 
				
			||||||
    import unittest  # noqa
 | 
					    import unittest  # noqa
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from uuid import uuid4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from cassandra.query import FETCH_SIZE_UNSET
 | 
					from cassandra.query import FETCH_SIZE_UNSET
 | 
				
			||||||
from cassandra.cqlengine.statements import BaseCQLStatement
 | 
					from cassandra.cqlengine.statements import BaseCQLStatement
 | 
				
			||||||
 | 
					from cassandra.cqlengine.management import sync_table, drop_table
 | 
				
			||||||
 | 
					from cassandra.cqlengine.statements import InsertStatement, UpdateStatement, SelectStatement, DeleteStatement, \
 | 
				
			||||||
 | 
					    WhereClause
 | 
				
			||||||
 | 
					from cassandra.cqlengine.operators import EqualsOperator
 | 
				
			||||||
 | 
					from cassandra.cqlengine.columns import Column
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel
 | 
				
			||||||
 | 
					from tests.integration.cqlengine import DEFAULT_KEYSPACE
 | 
				
			||||||
 | 
					from cassandra.cqlengine.connection import execute
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseStatementTest(unittest.TestCase):
 | 
					class BaseStatementTest(unittest.TestCase):
 | 
				
			||||||
@@ -32,3 +43,71 @@ class BaseStatementTest(unittest.TestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        stmt = BaseCQLStatement('table', None)
 | 
					        stmt = BaseCQLStatement('table', None)
 | 
				
			||||||
        self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET)
 | 
					        self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ExecuteStatementTest(BaseCassEngTestCase):
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def setUpClass(cls):
 | 
				
			||||||
 | 
					        super(ExecuteStatementTest, cls).setUpClass()
 | 
				
			||||||
 | 
					        sync_table(TestQueryUpdateModel)
 | 
				
			||||||
 | 
					        cls.table_name = '{0}.test_query_update_model'.format(DEFAULT_KEYSPACE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def tearDownClass(cls):
 | 
				
			||||||
 | 
					        super(ExecuteStatementTest, cls).tearDownClass()
 | 
				
			||||||
 | 
					        drop_table(TestQueryUpdateModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _verify_statement(self, original):
 | 
				
			||||||
 | 
					        st = SelectStatement(self.table_name)
 | 
				
			||||||
 | 
					        result = execute(st)
 | 
				
			||||||
 | 
					        response = result[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for assignment in original.assignments:
 | 
				
			||||||
 | 
					            self.assertEqual(response[assignment.field], assignment.value)
 | 
				
			||||||
 | 
					        self.assertEqual(len(response), 7)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_insert_statement_execute(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Test to verify the execution of BaseCQLStatements using connection.execute
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        @since 3.10
 | 
				
			||||||
 | 
					        @jira_ticket PYTHON-505
 | 
				
			||||||
 | 
					        @expected_result inserts a row in C*, updates the rows and then deletes
 | 
				
			||||||
 | 
					        all the rows using BaseCQLStatements
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        @test_category data_types:object_mapper
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        partition = uuid4()
 | 
				
			||||||
 | 
					        cluster = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        #Verifying insert statement
 | 
				
			||||||
 | 
					        st = InsertStatement(self.table_name)
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='partition'), partition)
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='cluster'), cluster)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='count'), 1)
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='text'), "text_for_db")
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='text_set'), set(("foo", "bar")))
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='text_list'), ["foo", "bar"])
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='text_map'), {"foo": '1', "bar": '2'})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        execute(st)
 | 
				
			||||||
 | 
					        self._verify_statement(st)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Verifying update statement
 | 
				
			||||||
 | 
					        where = [WhereClause('partition', EqualsOperator(), partition),
 | 
				
			||||||
 | 
					                 WhereClause('cluster', EqualsOperator(), cluster)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        st = UpdateStatement(self.table_name, where=where)
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='count'), 2)
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='text'), "text_for_db_update")
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='text_set'), set(("foo_update", "bar_update")))
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='text_list'), ["foo_update", "bar_update"])
 | 
				
			||||||
 | 
					        st.add_assignment(Column(db_field='text_map'), {"foo": '3', "bar": '4'})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        execute(st)
 | 
				
			||||||
 | 
					        self._verify_statement(st)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Verifying delete statement
 | 
				
			||||||
 | 
					        execute(DeleteStatement(self.table_name, where=where))
 | 
				
			||||||
 | 
					        self.assertEqual(TestQueryUpdateModel.objects.count(), 0)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user