diff --git a/cqlengine/tests/base.py b/cqlengine/tests/base.py index dea0cc19..5bd66b64 100644 --- a/cqlengine/tests/base.py +++ b/cqlengine/tests/base.py @@ -4,7 +4,10 @@ import sys import six from cqlengine.connection import get_session + CASSANDRA_VERSION = int(os.environ['CASSANDRA_VERSION']) +PROTOCOL_VERSION = 1 if CASSANDRA_VERSION < 20 else 2 + class BaseCassEngTestCase(TestCase): diff --git a/cqlengine/tests/columns/test_static_column.py b/cqlengine/tests/columns/test_static_column.py index 34c553dc..e8cfdedb 100644 --- a/cqlengine/tests/columns/test_static_column.py +++ b/cqlengine/tests/columns/test_static_column.py @@ -4,10 +4,9 @@ from cqlengine import Model from cqlengine import columns from cqlengine.management import sync_table, drop_table from cqlengine.models import ModelDefinitionException -from cqlengine.tests.base import BaseCassEngTestCase, CASSANDRA_VERSION -from cqlengine.connection import get_cluster +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.tests.base import CASSANDRA_VERSION, PROTOCOL_VERSION -cluster = get_cluster() class TestStaticModel(Model): __keyspace__ = 'test' @@ -31,7 +30,7 @@ class TestStaticColumn(BaseCassEngTestCase): super(TestStaticColumn, cls).tearDownClass() drop_table(TestStaticModel) - @skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") + @skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_mixed_updates(self): """ Tests that updates on both static and non-static columns work as intended """ instance = TestStaticModel.create() @@ -47,7 +46,7 @@ class TestStaticColumn(BaseCassEngTestCase): assert actual.static == "it's still shared" - @skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") + @skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_static_only_updates(self): """ Tests that updates on static only column work as intended """ instance = TestStaticModel.create() diff --git a/cqlengine/tests/management/test_management.py b/cqlengine/tests/management/test_management.py index 3bda88ed..55258e52 100644 --- a/cqlengine/tests/management/test_management.py +++ b/cqlengine/tests/management/test_management.py @@ -3,14 +3,14 @@ from cqlengine import ALL, CACHING_ALL, CACHING_NONE from cqlengine.connection import get_session from cqlengine.exceptions import CQLEngineException from cqlengine.management import get_fields, sync_table, drop_table -from cqlengine.tests.base import BaseCassEngTestCase, CASSANDRA_VERSION +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.tests.base import CASSANDRA_VERSION, PROTOCOL_VERSION from cqlengine import management from cqlengine.tests.query.test_queryset import TestModel from cqlengine.models import Model from cqlengine import columns, SizeTieredCompactionStrategy, LeveledCompactionStrategy from unittest import skipUnless -from cqlengine.connection import get_cluster -cluster = get_cluster() + class CreateKeyspaceTest(BaseCassEngTestCase): def test_create_succeeeds(self): @@ -256,7 +256,7 @@ class NonModelFailureTest(BaseCassEngTestCase): sync_table(self.FakeModel) -@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") +@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_static_columns(): class StaticModel(Model): id = columns.Integer(primary_key=True) diff --git a/cqlengine/tests/query/test_queryset.py b/cqlengine/tests/query/test_queryset.py index 00b1be49..8383ccdf 100644 --- a/cqlengine/tests/query/test_queryset.py +++ b/cqlengine/tests/query/test_queryset.py @@ -20,10 +20,8 @@ from datetime import tzinfo from cqlengine import statements from cqlengine import operators - -from cqlengine.connection import get_cluster, get_session - -cluster = get_cluster() +from cqlengine.connection import get_session +from cqlengine.tests.base import PROTOCOL_VERSION class TzOffset(tzinfo): @@ -690,7 +688,7 @@ class TestObjectsProperty(BaseQuerySetUsage): assert TestModel.objects._result_cache is None -@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") +@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_paged_result_handling(): # addresses #225 class PagingTest(Model): diff --git a/cqlengine/tests/test_ifnotexists.py b/cqlengine/tests/test_ifnotexists.py index 202a1c03..6749ae88 100644 --- a/cqlengine/tests/test_ifnotexists.py +++ b/cqlengine/tests/test_ifnotexists.py @@ -1,14 +1,12 @@ from unittest import skipUnless from cqlengine.management import sync_table, drop_table, create_keyspace, delete_keyspace from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.tests.base import PROTOCOL_VERSION from cqlengine.models import Model from cqlengine.exceptions import LWTException from cqlengine import columns, BatchQuery from uuid import uuid4 import mock -from cqlengine.connection import get_cluster - -cluster = get_cluster() class TestIfNotExistsModel(Model): @@ -43,7 +41,7 @@ class BaseIfNotExistsTest(BaseCassEngTestCase): class IfNotExistsInsertTests(BaseIfNotExistsTest): - @skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") + @skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_insert_if_not_exists_success(self): """ tests that insertion with if_not_exists work as expected """ @@ -77,7 +75,7 @@ class IfNotExistsInsertTests(BaseIfNotExistsTest): self.assertEquals(tm.count, 9) self.assertEquals(tm.text, '111111111111') - @skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") + @skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_insert_if_not_exists_success(self): """ tests that batch insertion with if_not_exists work as expected """