use a constant to get protocol_version instead of module-level calls to get_cluster

This commit is contained in:
Amy Hanlon
2015-01-04 17:56:17 -05:00
parent 849b569bb6
commit 3d06858b80
5 changed files with 17 additions and 19 deletions

View File

@@ -4,7 +4,10 @@ import sys
import six
from cqlengine.connection import get_session
CASSANDRA_VERSION = int(os.environ['CASSANDRA_VERSION'])
PROTOCOL_VERSION = 1 if CASSANDRA_VERSION < 20 else 2
class BaseCassEngTestCase(TestCase):

View File

@@ -4,10 +4,9 @@ from cqlengine import Model
from cqlengine import columns
from cqlengine.management import sync_table, drop_table
from cqlengine.models import ModelDefinitionException
from cqlengine.tests.base import BaseCassEngTestCase, CASSANDRA_VERSION
from cqlengine.connection import get_cluster
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.tests.base import CASSANDRA_VERSION, PROTOCOL_VERSION
cluster = get_cluster()
class TestStaticModel(Model):
__keyspace__ = 'test'
@@ -31,7 +30,7 @@ class TestStaticColumn(BaseCassEngTestCase):
super(TestStaticColumn, cls).tearDownClass()
drop_table(TestStaticModel)
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_mixed_updates(self):
""" Tests that updates on both static and non-static columns work as intended """
instance = TestStaticModel.create()
@@ -47,7 +46,7 @@ class TestStaticColumn(BaseCassEngTestCase):
assert actual.static == "it's still shared"
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_static_only_updates(self):
""" Tests that updates on static only column work as intended """
instance = TestStaticModel.create()

View File

@@ -3,14 +3,14 @@ from cqlengine import ALL, CACHING_ALL, CACHING_NONE
from cqlengine.connection import get_session
from cqlengine.exceptions import CQLEngineException
from cqlengine.management import get_fields, sync_table, drop_table
from cqlengine.tests.base import BaseCassEngTestCase, CASSANDRA_VERSION
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.tests.base import CASSANDRA_VERSION, PROTOCOL_VERSION
from cqlengine import management
from cqlengine.tests.query.test_queryset import TestModel
from cqlengine.models import Model
from cqlengine import columns, SizeTieredCompactionStrategy, LeveledCompactionStrategy
from unittest import skipUnless
from cqlengine.connection import get_cluster
cluster = get_cluster()
class CreateKeyspaceTest(BaseCassEngTestCase):
def test_create_succeeeds(self):
@@ -256,7 +256,7 @@ class NonModelFailureTest(BaseCassEngTestCase):
sync_table(self.FakeModel)
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_static_columns():
class StaticModel(Model):
id = columns.Integer(primary_key=True)

View File

@@ -20,10 +20,8 @@ from datetime import tzinfo
from cqlengine import statements
from cqlengine import operators
from cqlengine.connection import get_cluster, get_session
cluster = get_cluster()
from cqlengine.connection import get_session
from cqlengine.tests.base import PROTOCOL_VERSION
class TzOffset(tzinfo):
@@ -690,7 +688,7 @@ class TestObjectsProperty(BaseQuerySetUsage):
assert TestModel.objects._result_cache is None
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_paged_result_handling():
# addresses #225
class PagingTest(Model):

View File

@@ -1,14 +1,12 @@
from unittest import skipUnless
from cqlengine.management import sync_table, drop_table, create_keyspace, delete_keyspace
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.tests.base import PROTOCOL_VERSION
from cqlengine.models import Model
from cqlengine.exceptions import LWTException
from cqlengine import columns, BatchQuery
from uuid import uuid4
import mock
from cqlengine.connection import get_cluster
cluster = get_cluster()
class TestIfNotExistsModel(Model):
@@ -43,7 +41,7 @@ class BaseIfNotExistsTest(BaseCassEngTestCase):
class IfNotExistsInsertTests(BaseIfNotExistsTest):
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_insert_if_not_exists_success(self):
""" tests that insertion with if_not_exists work as expected """
@@ -77,7 +75,7 @@ class IfNotExistsInsertTests(BaseIfNotExistsTest):
self.assertEquals(tm.count, 9)
self.assertEquals(tm.text, '111111111111')
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_batch_insert_if_not_exists_success(self):
""" tests that batch insertion with if_not_exists work as expected """