From f1fb9da47c867f4edd57327de0f4a3a773fd0068 Mon Sep 17 00:00:00 2001 From: Jon Haddad Date: Thu, 24 Oct 2013 15:36:20 -0700 Subject: [PATCH] ensure consistency is called with the right param --- cqlengine/connection.py | 11 +++++---- cqlengine/tests/test_consistency.py | 35 +++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 cqlengine/tests/test_consistency.py diff --git a/cqlengine/connection.py b/cqlengine/connection.py index f6e25e4e..df89e430 100644 --- a/cqlengine/connection.py +++ b/cqlengine/connection.py @@ -165,10 +165,10 @@ class ConnectionPool(object): from thrift.transport import TSocket, TTransport thrift_socket = TSocket.TSocket(host.name, host.port) - + if self._timeout is not None: thrift_socket.setTimeout(self._timeout) - + return TTransport.TFramedTransport(thrift_socket) def _create_connection(self): @@ -202,14 +202,17 @@ class ConnectionPool(object): raise CQLConnectionError("Could not connect to any server in cluster") - def execute(self, query, params): + def execute(self, query, params, consistency_level=None): + if not consistency_level: + consistency_level = self._consistency + while True: try: con = self.get() if not con: raise CQLEngineException("Error calling execute without calling setup.") cur = con.cursor() - cur.execute(query, params) + cur.execute(query, params, consistency_level=consistency_level) columns = [i[0] for i in cur.description or []] results = [RowResult(r) for r in cur.fetchall()] LOG.debug('{} {}'.format(query, repr(params))) diff --git a/cqlengine/tests/test_consistency.py b/cqlengine/tests/test_consistency.py new file mode 100644 index 00000000..cf1673ed --- /dev/null +++ b/cqlengine/tests/test_consistency.py @@ -0,0 +1,35 @@ +from cqlengine.management import sync_table, drop_table +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.models import Model +from uuid import uuid4 +from cqlengine import columns +import mock +from cqlengine.connection import ConnectionPool +from cqlengine import ALL + +class TestConsistencyModel(Model): + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + +class BaseConsistencyTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseConsistencyTest, cls).setUpClass() + sync_table(TestConsistencyModel) + + @classmethod + def tearDownClass(cls): + super(BaseConsistencyTest, cls).tearDownClass() + drop_table(TestConsistencyModel) + + +class TestConsistency(BaseConsistencyTest): + def test_create_uses_consistency(self): + + with mock.patch.object(ConnectionPool, 'execute') as m: + TestConsistencyModel.consistency(ALL).create(text="i am not fault tolerant this way") + + args = m.call_args + self.assertEqual(ALL, args[2])