diff --git a/cassandra/cqlengine/connection.py b/cassandra/cqlengine/connection.py index 3d3cd89c..3070e67e 100644 --- a/cassandra/cqlengine/connection.py +++ b/cassandra/cqlengine/connection.py @@ -146,7 +146,7 @@ def unregister_connection(name): if name not in _connections: return - if _connections[name] == _connections[DEFAULT_CONNECTION]: + if DEFAULT_CONNECTION in _connections and _connections[name] == _connections[DEFAULT_CONNECTION]: del _connections[DEFAULT_CONNECTION] log.warning("Unregistering default connection '{0}'. Use set_default_connection to set a new one.".format(name)) diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py index 286d8c1f..1b198739 100644 --- a/cassandra/cqlengine/models.py +++ b/cassandra/cqlengine/models.py @@ -31,9 +31,12 @@ log = logging.getLogger(__name__) def _clone_model_class(model, attrs): new_type = type(model.__name__, (model,), attrs) - new_type.__abstract__ = model.__abstract__ - new_type.__discriminator_value__ = model.__discriminator_value__ - new_type.__default_ttl__ = model.__default_ttl__ + try: + new_type.__abstract__ = model.__abstract__ + new_type.__discriminator_value__ = model.__discriminator_value__ + new_type.__default_ttl__ = model.__default_ttl__ + except AttributeError: + pass return new_type @@ -803,6 +806,8 @@ class BaseModel(object): def _inst_batch(self, batch): assert self._timeout is connection.NOT_SET, 'Setting both timeout and batch is not supported' + if self._connection: + raise CQLEngineException("Cannot specify a connection on model in batch mode.") self._batch = batch return self diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index 8b968c6f..f47b025f 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -302,18 +302,20 @@ class ContextQuery(object): self.models = [] if len(args) < 1: - raise CQLEngineException("No model provided.") + raise ValueError("No model provided.") keyspace = kwargs.pop('keyspace', None) connection = kwargs.pop('connection', None) if kwargs: - raise CQLEngineException("Unknown keyword argument(s): {0}".format( + raise ValueError("Unknown keyword argument(s): {0}".format( ','.join(kwargs.keys()))) for model in args: - if not issubclass(model, models.Model): - raise CQLEngineException("Models must be derived from base Model.") + try: + issubclass(model, models.Model) + except TypeError: + raise ValueError("Models must be derived from base Model.") m = models._clone_model_class(model, {}) @@ -390,7 +392,8 @@ class AbstractQuerySet(object): if self._batch: return self._batch.add_query(statement) else: - result = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=self._connection) + connection = self._connection if self._connection else self.model._get_connection() + result = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection) if self._if_not_exists or self._if_exists or self._conditional: check_applied(result) return result diff --git a/tests/integration/cqlengine/test_connections.py b/tests/integration/cqlengine/test_connections.py new file mode 100644 index 00000000..f2f88b0f --- /dev/null +++ b/tests/integration/cqlengine/test_connections.py @@ -0,0 +1,389 @@ +# Copyright 2013-2016 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra import InvalidRequest +from cassandra.cluster import NoHostAvailable +from cassandra.cqlengine import columns, CQLEngineException +from cassandra.cqlengine import connection as conn +from cassandra.cqlengine.management import drop_keyspace, sync_table, drop_table, create_keyspace_simple +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import ContextQuery, BatchQuery +from tests.integration.cqlengine import setup_connection, DEFAULT_KEYSPACE +from tests.integration.cqlengine.base import BaseCassEngTestCase + + +class TestModel(Model): + + __keyspace__ = 'ks1' + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer() + text = columns.Text() + + +class AnotherTestModel(Model): + + __keyspace__ = 'ks1' + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer() + text = columns.Text() + + +class ContextQueryConnectionTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(ContextQueryConnectionTests, cls).setUpClass() + create_keyspace_simple('ks1', 1) + + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['127.0.0.100'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', ['127.0.0.1']) + + with ContextQuery(TestModel, connection='cluster') as tm: + sync_table(tm) + + @classmethod + def tearDownClass(cls): + super(ContextQueryConnectionTests, cls).tearDownClass() + + with ContextQuery(TestModel, connection='cluster') as tm: + drop_table(tm) + drop_keyspace('ks1', connections=['cluster']) + + + # reset the default connection + conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def test_context_connection_priority(self): + + # Set the default connection on the Model + TestModel.__connection__ = 'cluster' + with ContextQuery(TestModel) as tm: + tm.objects.create(partition=1, cluster=1) + + # ContextQuery connection should have priority over default one + with ContextQuery(TestModel, connection='fake_cluster') as tm: + with self.assertRaises(NoHostAvailable): + tm.objects.create(partition=1, cluster=1) + + # Explicit connection should have priority over ContextQuery one + with ContextQuery(TestModel, connection='fake_cluster') as tm: + tm.objects.using(connection='cluster').create(partition=1, cluster=1) + + # Reset the default conn of the model + TestModel.__connection__ = None + + # No model connection and an invalid default connection + with ContextQuery(TestModel) as tm: + with self.assertRaises(NoHostAvailable): + tm.objects.create(partition=1, cluster=1) + + def test_context_connection_with_keyspace(self): + + # ks2 doesn't exist + with ContextQuery(TestModel, connection='cluster', keyspace='ks2') as tm: + with self.assertRaises(InvalidRequest): + tm.objects.create(partition=1, cluster=1) + + +class ManagementConnectionTests(BaseCassEngTestCase): + + keyspaces = ['ks1', 'ks2'] + conns = ['cluster'] + + @classmethod + def setUpClass(cls): + super(ManagementConnectionTests, cls).setUpClass() + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['127.0.0.100'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', ['127.0.0.1']) + + + @classmethod + def tearDownClass(cls): + super(ManagementConnectionTests, cls).tearDownClass() + + # reset the default connection + conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def test_create_drop_keyspace(self): + + # No connection (default is fake) + with self.assertRaises(NoHostAvailable): + create_keyspace_simple(self.keyspaces[0], 1) + + # Explicit connections + for ks in self.keyspaces: + create_keyspace_simple(ks, 1, connections=self.conns) + + for ks in self.keyspaces: + drop_keyspace(ks, connections=self.conns) + + def test_create_drop_table(self): + + for ks in self.keyspaces: + create_keyspace_simple(ks, 1, connections=self.conns) + + # No connection (default is fake) + with self.assertRaises(NoHostAvailable): + sync_table(TestModel) + + # Explicit connections + sync_table(TestModel, connections=self.conns) + + # Explicit drop + drop_table(TestModel, connections=self.conns) + + # Model connection + TestModel.__connection__ = 'cluster' + sync_table(TestModel) + TestModel.__connection__ = None + + # No connection (default is fake) + with self.assertRaises(NoHostAvailable): + drop_table(TestModel) + + # Model connection + TestModel.__connection__ = 'cluster' + drop_table(TestModel) + TestModel.__connection__ = None + + # Model connection + for ks in self.keyspaces: + drop_keyspace(ks, connections=self.conns) + + +class BatchQueryConnectionTests(BaseCassEngTestCase): + + conns = ['cluster'] + + @classmethod + def setUpClass(cls): + super(BatchQueryConnectionTests, cls).setUpClass() + + create_keyspace_simple('ks1', 1) + sync_table(TestModel) + sync_table(AnotherTestModel) + + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['127.0.0.100'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', ['127.0.0.1']) + + + @classmethod + def tearDownClass(cls): + super(BatchQueryConnectionTests, cls).tearDownClass() + + # reset the default connection + conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + drop_keyspace('ks1') + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def test_basic_batch_query(self): + """Test BatchQuery requests""" + + # No connection with a QuerySet (default is a fake one) + with self.assertRaises(NoHostAvailable): + with BatchQuery() as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + + # Explicit connection with a QuerySet + with BatchQuery(connection='cluster') as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + + # Get an object from the BD + with ContextQuery(TestModel, connection='cluster') as tm: + obj = tm.objects.get(partition=1, cluster=1) + obj.__connection__ = None + + # No connection with a model (default is a fake one) + with self.assertRaises(NoHostAvailable): + with BatchQuery() as b: + obj.count = 2 + obj.batch(b).save() + + # Explicit connection with a model + with BatchQuery(connection='cluster') as b: + obj.count = 2 + obj.batch(b).save() + + def test_batch_query_different_connection(self): + """Test BatchQuery with Models that have a different connection""" + + # Testing on a model class + TestModel.__connection__ = 'cluster' + AnotherTestModel.__connection__ = 'cluster2' + + with self.assertRaises(CQLEngineException): + with BatchQuery() as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + AnotherTestModel.objects.batch(b).create(partition=1, cluster=1) + + TestModel.__connection__ = None + AnotherTestModel.__connection__ = None + + with BatchQuery(connection='cluster') as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + AnotherTestModel.objects.batch(b).create(partition=1, cluster=1) + + # Testing on a model instance + with ContextQuery(TestModel, AnotherTestModel, connection='cluster') as (tm, atm): + obj1 = tm.objects.get(partition=1, cluster=1) + obj2 = atm.objects.get(partition=1, cluster=1) + + obj1.__connection__ = 'cluster' + obj2.__connection__ = 'cluster2' + + obj1.count = 4 + obj2.count = 4 + + with self.assertRaises(CQLEngineException): + with BatchQuery() as b: + obj1.batch(b).save() + obj2.batch(b).save() + + def test_batch_query_connection_override(self): + """Test that we cannot override a BatchQuery connection per model""" + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + TestModel.batch(b).using(connection='test').save() + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + TestModel.using(connection='test').batch(b).save() + + with ContextQuery(TestModel, AnotherTestModel, connection='cluster') as (tm, atm): + obj1 = tm.objects.get(partition=1, cluster=1) + obj1.__connection__ = None + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + obj1.using(connection='test').batch(b).save() + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + obj1.batch(b).using(connection='test').save() + + +class UsingDescriptorTests(BaseCassEngTestCase): + + conns = ['cluster'] + keyspaces = ['ks1', 'ks2'] + + @classmethod + def setUpClass(cls): + super(UsingDescriptorTests, cls).setUpClass() + + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['127.0.0.100'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', ['127.0.0.1']) + + + @classmethod + def tearDownClass(cls): + super(UsingDescriptorTests, cls).tearDownClass() + + # reset the default connection + conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + for ks in cls.keyspaces: + drop_keyspace(ks) + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def _reset_data(self): + + for ks in self.keyspaces: + drop_keyspace(ks, connections=self.conns) + + for ks in self.keyspaces: + create_keyspace_simple(ks, 1, connections=self.conns) + sync_table(TestModel, keyspaces=self.keyspaces, connections=self.conns) + + def test_keyspace(self): + + self._reset_data() + + with ContextQuery(TestModel, connection='cluster') as tm: + + # keyspace Model class + tm.objects.using(keyspace='ks2').create(partition=1, cluster=1) + tm.objects.using(keyspace='ks2').create(partition=2, cluster=2) + + with self.assertRaises(TestModel.DoesNotExist): + tm.objects.get(partition=1, cluster=1) # default keyspace ks1 + obj1 = tm.objects.using(keyspace='ks2').get(partition=1, cluster=1) + + obj1.count = 2 + obj1.save() + + with self.assertRaises(NoHostAvailable): + TestModel.objects.using(keyspace='ks2').get(partition=1, cluster=1) + + obj2 = TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=1, cluster=1) + self.assertEqual(obj2.count, 2) + + # Update test + TestModel.objects(partition=2, cluster=2).using(connection='cluster', keyspace='ks2').update(count=5) + obj3 = TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=2, cluster=2) + self.assertEqual(obj3.count, 5) + + TestModel.objects(partition=2, cluster=2).using(connection='cluster', keyspace='ks2').delete() + with self.assertRaises(TestModel.DoesNotExist): + TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=2, cluster=2) + + def test_connection(self): + + self._reset_data() + + # Model class + with self.assertRaises(NoHostAvailable): + TestModel.objects.create(partition=1, cluster=1) + + TestModel.objects.using(connection='cluster').create(partition=1, cluster=1) + TestModel.objects(partition=1, cluster=1).using(connection='cluster').update(count=2) + obj1 = TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) + self.assertEqual(obj1.count, 2) + + obj1.using(connection='cluster').update(count=5) + obj1 = TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) + self.assertEqual(obj1.count, 5) + + obj1.using(connection='cluster').delete() + with self.assertRaises(TestModel.DoesNotExist): + TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) diff --git a/tests/integration/cqlengine/test_context_query.py b/tests/integration/cqlengine/test_context_query.py index b3941319..0a29688d 100644 --- a/tests/integration/cqlengine/test_context_query.py +++ b/tests/integration/cqlengine/test_context_query.py @@ -46,6 +46,7 @@ class ContextQueryTests(BaseCassEngTestCase): for ks in cls.KEYSPACES: drop_keyspace(ks) + def setUp(self): super(ContextQueryTests, self).setUp() for ks in self.KEYSPACES: @@ -125,3 +126,50 @@ class ContextQueryTests(BaseCassEngTestCase): self.assertEqual(42, tm.objects.get(partition=1).count) + def test_context_multiple_models(self): + """ + Tests the use of multiple models with the context manager + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result all models are properly updated with the context + + @test_category query + """ + + with ContextQuery(TestModel, TestModel, keyspace='ks4') as (tm1, tm2): + + self.assertNotEqual(tm1, tm2) + self.assertEqual(tm1.__keyspace__, 'ks4') + self.assertEqual(tm2.__keyspace__, 'ks4') + + def test_context_invalid_parameters(self): + """ + Tests that invalid parameters are raised by the context manager + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result a ValueError is raised when passing invalid parameters + + @test_category query + """ + + with self.assertRaises(ValueError): + with ContextQuery(keyspace='ks2'): + pass + + with self.assertRaises(ValueError): + with ContextQuery(42) as tm: + pass + + with self.assertRaises(ValueError): + with ContextQuery(TestModel, 42): + pass + + with self.assertRaises(ValueError): + with ContextQuery(TestModel, unknown_param=42): + pass + + with self.assertRaises(ValueError): + with ContextQuery(TestModel, keyspace='ks2', unknown_param=42): + pass \ No newline at end of file