# Copyright 2013-2017 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import uuid import mock from cassandra.cqlengine import columns from cassandra.cqlengine import models from cassandra.cqlengine.connection import get_session from tests.integration.cqlengine.base import BaseCassEngTestCase from cassandra.cqlengine import management class TestInheritanceClassConstruction(BaseCassEngTestCase): def test_multiple_discriminator_value_failure(self): """ Tests that defining a model with more than one discriminator column fails """ with self.assertRaises(models.ModelDefinitionException): class M(models.Model): partition = columns.Integer(primary_key=True) type1 = columns.Integer(discriminator_column=True) type2 = columns.Integer(discriminator_column=True) def test_no_discriminator_column_failure(self): with self.assertRaises(models.ModelDefinitionException): class M(models.Model): __discriminator_value__ = 1 def test_discriminator_value_inheritance(self): """ Tests that discriminator_column attribute is not inherited """ class Base(models.Model): partition = columns.Integer(primary_key=True) type1 = columns.Integer(discriminator_column=True) class M1(Base): __discriminator_value__ = 1 class M2(M1): pass assert M2.__discriminator_value__ is None def test_inheritance_metaclass(self): """ Tests that the model meta class configures inherited models properly """ class Base(models.Model): partition = columns.Integer(primary_key=True) type1 = columns.Integer(discriminator_column=True) class M1(Base): __discriminator_value__ = 1 assert Base._is_polymorphic assert M1._is_polymorphic assert Base._is_polymorphic_base assert not M1._is_polymorphic_base assert Base._discriminator_column is Base._columns['type1'] assert M1._discriminator_column is M1._columns['type1'] assert Base._discriminator_column_name == 'type1' assert M1._discriminator_column_name == 'type1' def test_table_names_are_inherited_from_base(self): class Base(models.Model): partition = columns.Integer(primary_key=True) type1 = columns.Integer(discriminator_column=True) class M1(Base): __discriminator_value__ = 1 assert Base.column_family_name() == M1.column_family_name() def test_collection_columns_cant_be_discriminator_column(self): with self.assertRaises(models.ModelDefinitionException): class Base(models.Model): partition = columns.Integer(primary_key=True) type1 = columns.Set(columns.Integer, discriminator_column=True) class InheritBase(models.Model): partition = columns.UUID(primary_key=True, default=uuid.uuid4) row_type = columns.Integer(discriminator_column=True) class Inherit1(InheritBase): __discriminator_value__ = 1 data1 = columns.Text() class Inherit2(InheritBase): __discriminator_value__ = 2 data2 = columns.Text() class TestInheritanceModel(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(TestInheritanceModel, cls).setUpClass() management.sync_table(Inherit1) management.sync_table(Inherit2) @classmethod def tearDownClass(cls): super(TestInheritanceModel, cls).tearDownClass() management.drop_table(Inherit1) management.drop_table(Inherit2) def test_saving_base_model_fails(self): with self.assertRaises(models.PolymorphicModelException): InheritBase.create() def test_saving_subclass_saves_disc_value(self): p1 = Inherit1.create(data1='pickle') p2 = Inherit2.create(data2='bacon') assert p1.row_type == Inherit1.__discriminator_value__ assert p2.row_type == Inherit2.__discriminator_value__ def test_query_deserialization(self): p1 = Inherit1.create(data1='pickle') p2 = Inherit2.create(data2='bacon') p1r = InheritBase.get(partition=p1.partition) p2r = InheritBase.get(partition=p2.partition) assert isinstance(p1r, Inherit1) assert isinstance(p2r, Inherit2) def test_delete_on_subclass_does_not_include_disc_value(self): p1 = Inherit1.create() session = get_session() with mock.patch.object(session, 'execute') as m: Inherit1.objects(partition=p1.partition).delete() # make sure our discriminator value isn't in the CQL # not sure how we would even get here if it was in there # since the CQL would fail. self.assertNotIn("row_type", m.call_args[0][0].query_string) class UnindexedInheritBase(models.Model): partition = columns.UUID(primary_key=True, default=uuid.uuid4) cluster = columns.UUID(primary_key=True, default=uuid.uuid4) row_type = columns.Integer(discriminator_column=True) class UnindexedInherit1(UnindexedInheritBase): __discriminator_value__ = 1 data1 = columns.Text() class UnindexedInherit2(UnindexedInheritBase): __discriminator_value__ = 2 data2 = columns.Text() class UnindexedInherit3(UnindexedInherit2): __discriminator_value__ = 3 data3 = columns.Text() class TestUnindexedInheritanceQuery(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(TestUnindexedInheritanceQuery, cls).setUpClass() management.sync_table(UnindexedInherit1) management.sync_table(UnindexedInherit2) management.sync_table(UnindexedInherit3) cls.p1 = UnindexedInherit1.create(data1='pickle') cls.p2 = UnindexedInherit2.create(partition=cls.p1.partition, data2='bacon') cls.p3 = UnindexedInherit3.create(partition=cls.p1.partition, data3='turkey') @classmethod def tearDownClass(cls): super(TestUnindexedInheritanceQuery, cls).tearDownClass() management.drop_table(UnindexedInherit1) management.drop_table(UnindexedInherit2) management.drop_table(UnindexedInherit3) def test_non_conflicting_type_results_work(self): p1, p2, p3 = self.p1, self.p2, self.p3 assert len(list(UnindexedInherit1.objects(partition=p1.partition, cluster=p1.cluster))) == 1 assert len(list(UnindexedInherit2.objects(partition=p1.partition, cluster=p2.cluster))) == 1 assert len(list(UnindexedInherit3.objects(partition=p1.partition, cluster=p3.cluster))) == 1 def test_subclassed_model_results_work_properly(self): p1, p2, p3 = self.p1, self.p2, self.p3 assert len(list(UnindexedInherit2.objects(partition=p1.partition, cluster__in=[p2.cluster, p3.cluster]))) == 2 def test_conflicting_type_results(self): with self.assertRaises(models.PolymorphicModelException): list(UnindexedInherit1.objects(partition=self.p1.partition)) with self.assertRaises(models.PolymorphicModelException): list(UnindexedInherit2.objects(partition=self.p1.partition)) class IndexedInheritBase(models.Model): partition = columns.UUID(primary_key=True, default=uuid.uuid4) cluster = columns.UUID(primary_key=True, default=uuid.uuid4) row_type = columns.Integer(discriminator_column=True, index=True) class IndexedInherit1(IndexedInheritBase): __discriminator_value__ = 1 data1 = columns.Text() class IndexedInherit2(IndexedInheritBase): __discriminator_value__ = 2 data2 = columns.Text() class TestIndexedInheritanceQuery(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(TestIndexedInheritanceQuery, cls).setUpClass() management.sync_table(IndexedInherit1) management.sync_table(IndexedInherit2) cls.p1 = IndexedInherit1.create(data1='pickle') cls.p2 = IndexedInherit2.create(partition=cls.p1.partition, data2='bacon') @classmethod def tearDownClass(cls): super(TestIndexedInheritanceQuery, cls).tearDownClass() management.drop_table(IndexedInherit1) management.drop_table(IndexedInherit2) def test_success_case(self): self.assertEqual(len(list(IndexedInherit1.objects(partition=self.p1.partition))), 1) self.assertEqual(len(list(IndexedInherit2.objects(partition=self.p1.partition))), 1)