use a constant to get protocol_version instead of module-level calls to get_cluster

This commit is contained in:
Amy Hanlon
2015-01-04 17:56:17 -05:00
parent 849b569bb6
commit 3d06858b80
5 changed files with 17 additions and 19 deletions

View File

@@ -4,7 +4,10 @@ import sys
import six import six
from cqlengine.connection import get_session from cqlengine.connection import get_session
CASSANDRA_VERSION = int(os.environ['CASSANDRA_VERSION']) CASSANDRA_VERSION = int(os.environ['CASSANDRA_VERSION'])
PROTOCOL_VERSION = 1 if CASSANDRA_VERSION < 20 else 2
class BaseCassEngTestCase(TestCase): class BaseCassEngTestCase(TestCase):

View File

@@ -4,10 +4,9 @@ from cqlengine import Model
from cqlengine import columns from cqlengine import columns
from cqlengine.management import sync_table, drop_table from cqlengine.management import sync_table, drop_table
from cqlengine.models import ModelDefinitionException from cqlengine.models import ModelDefinitionException
from cqlengine.tests.base import BaseCassEngTestCase, CASSANDRA_VERSION from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.connection import get_cluster from cqlengine.tests.base import CASSANDRA_VERSION, PROTOCOL_VERSION
cluster = get_cluster()
class TestStaticModel(Model): class TestStaticModel(Model):
__keyspace__ = 'test' __keyspace__ = 'test'
@@ -31,7 +30,7 @@ class TestStaticColumn(BaseCassEngTestCase):
super(TestStaticColumn, cls).tearDownClass() super(TestStaticColumn, cls).tearDownClass()
drop_table(TestStaticModel) drop_table(TestStaticModel)
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") @skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_mixed_updates(self): def test_mixed_updates(self):
""" Tests that updates on both static and non-static columns work as intended """ """ Tests that updates on both static and non-static columns work as intended """
instance = TestStaticModel.create() instance = TestStaticModel.create()
@@ -47,7 +46,7 @@ class TestStaticColumn(BaseCassEngTestCase):
assert actual.static == "it's still shared" assert actual.static == "it's still shared"
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") @skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_static_only_updates(self): def test_static_only_updates(self):
""" Tests that updates on static only column work as intended """ """ Tests that updates on static only column work as intended """
instance = TestStaticModel.create() instance = TestStaticModel.create()

View File

@@ -3,14 +3,14 @@ from cqlengine import ALL, CACHING_ALL, CACHING_NONE
from cqlengine.connection import get_session from cqlengine.connection import get_session
from cqlengine.exceptions import CQLEngineException from cqlengine.exceptions import CQLEngineException
from cqlengine.management import get_fields, sync_table, drop_table from cqlengine.management import get_fields, sync_table, drop_table
from cqlengine.tests.base import BaseCassEngTestCase, CASSANDRA_VERSION from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.tests.base import CASSANDRA_VERSION, PROTOCOL_VERSION
from cqlengine import management from cqlengine import management
from cqlengine.tests.query.test_queryset import TestModel from cqlengine.tests.query.test_queryset import TestModel
from cqlengine.models import Model from cqlengine.models import Model
from cqlengine import columns, SizeTieredCompactionStrategy, LeveledCompactionStrategy from cqlengine import columns, SizeTieredCompactionStrategy, LeveledCompactionStrategy
from unittest import skipUnless from unittest import skipUnless
from cqlengine.connection import get_cluster
cluster = get_cluster()
class CreateKeyspaceTest(BaseCassEngTestCase): class CreateKeyspaceTest(BaseCassEngTestCase):
def test_create_succeeeds(self): def test_create_succeeeds(self):
@@ -256,7 +256,7 @@ class NonModelFailureTest(BaseCassEngTestCase):
sync_table(self.FakeModel) sync_table(self.FakeModel)
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") @skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_static_columns(): def test_static_columns():
class StaticModel(Model): class StaticModel(Model):
id = columns.Integer(primary_key=True) id = columns.Integer(primary_key=True)

View File

@@ -20,10 +20,8 @@ from datetime import tzinfo
from cqlengine import statements from cqlengine import statements
from cqlengine import operators from cqlengine import operators
from cqlengine.connection import get_session
from cqlengine.connection import get_cluster, get_session from cqlengine.tests.base import PROTOCOL_VERSION
cluster = get_cluster()
class TzOffset(tzinfo): class TzOffset(tzinfo):
@@ -690,7 +688,7 @@ class TestObjectsProperty(BaseQuerySetUsage):
assert TestModel.objects._result_cache is None assert TestModel.objects._result_cache is None
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") @skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_paged_result_handling(): def test_paged_result_handling():
# addresses #225 # addresses #225
class PagingTest(Model): class PagingTest(Model):

View File

@@ -1,14 +1,12 @@
from unittest import skipUnless from unittest import skipUnless
from cqlengine.management import sync_table, drop_table, create_keyspace, delete_keyspace from cqlengine.management import sync_table, drop_table, create_keyspace, delete_keyspace
from cqlengine.tests.base import BaseCassEngTestCase from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.tests.base import PROTOCOL_VERSION
from cqlengine.models import Model from cqlengine.models import Model
from cqlengine.exceptions import LWTException from cqlengine.exceptions import LWTException
from cqlengine import columns, BatchQuery from cqlengine import columns, BatchQuery
from uuid import uuid4 from uuid import uuid4
import mock import mock
from cqlengine.connection import get_cluster
cluster = get_cluster()
class TestIfNotExistsModel(Model): class TestIfNotExistsModel(Model):
@@ -43,7 +41,7 @@ class BaseIfNotExistsTest(BaseCassEngTestCase):
class IfNotExistsInsertTests(BaseIfNotExistsTest): class IfNotExistsInsertTests(BaseIfNotExistsTest):
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") @skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_insert_if_not_exists_success(self): def test_insert_if_not_exists_success(self):
""" tests that insertion with if_not_exists work as expected """ """ tests that insertion with if_not_exists work as expected """
@@ -77,7 +75,7 @@ class IfNotExistsInsertTests(BaseIfNotExistsTest):
self.assertEquals(tm.count, 9) self.assertEquals(tm.count, 9)
self.assertEquals(tm.text, '111111111111') self.assertEquals(tm.text, '111111111111')
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") @skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_batch_insert_if_not_exists_success(self): def test_batch_insert_if_not_exists_success(self):
""" tests that batch insertion with if_not_exists work as expected """ """ tests that batch insertion with if_not_exists work as expected """