diff --git a/cqlengine/connection.py b/cqlengine/connection.py index 1f3be211..c7274e81 100644 --- a/cqlengine/connection.py +++ b/cqlengine/connection.py @@ -24,6 +24,7 @@ from contextlib import contextmanager from thrift.transport.TTransport import TTransportException from cqlengine.statements import BaseCQLStatement +from cassandra.query import dict_factory LOG = logging.getLogger('cqlengine.cql') @@ -83,6 +84,7 @@ def setup( cluster = Cluster(hosts) session = cluster.connect() + session.row_factory = dict_factory _max_connections = max_connections @@ -273,8 +275,14 @@ def execute_native(query, params=None, consistency_level=None): query = str(query) params = params or {} result = session.execute(query, params) - import ipdb; ipdb.set_trace() - return ([], result) + + if result: + keys = result[0].keys() + else: + keys = [] + + return QueryResult(keys, result) + def get_session(): return session diff --git a/cqlengine/management.py b/cqlengine/management.py index c0562bab..14a0203f 100644 --- a/cqlengine/management.py +++ b/cqlengine/management.py @@ -229,13 +229,13 @@ def get_fields(model): # Tables containing only primary keys do not appear to create # any entries in system.schema_columns, as only non-primary-key attributes # appear to be inserted into the schema_columns table - if not tmp: + if not tmp[1]: return [] try: - return [Field(x.column_name, x.validator) for x in tmp if x.type == 'regular'] + return [Field(x['column_name'], x['validator']) for x in tmp[1] if x['type'] == 'regular'] except ValueError: - return [Field(x.column_name, x.validator) for x in tmp] + return [Field(x['column_name'], x['validator']) for x in tmp[1]] # convert to Field named tuples diff --git a/cqlengine/query.py b/cqlengine/query.py index ef92f923..94ae9867 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -232,7 +232,6 @@ class AbstractQuerySet(object): return self._batch.add_query(q) else: result = execute_native(q, consistency_level=self._consistency) - import ipdb; ipdb.set_trace() return result def __unicode__(self): diff --git a/cqlengine/tests/statements/test_insert_statement.py b/cqlengine/tests/statements/test_insert_statement.py index eeffe556..32bb7435 100644 --- a/cqlengine/tests/statements/test_insert_statement.py +++ b/cqlengine/tests/statements/test_insert_statement.py @@ -17,7 +17,7 @@ class InsertStatementTests(TestCase): self.assertEqual( unicode(ist), - 'INSERT INTO table ("a", "c") VALUES (?, ?)' + 'INSERT INTO table ("a", "c") VALUES (%(0)s, %(1)s)' ) def test_context_update(self): @@ -28,7 +28,7 @@ class InsertStatementTests(TestCase): ist.update_context_id(4) self.assertEqual( unicode(ist), - 'INSERT INTO table ("a", "c") VALUES (?, ?)' + 'INSERT INTO table ("a", "c") VALUES (%(4)s, %(5)s)' ) ctx = ist.get_context() self.assertEqual(ctx, {'4': 'b', '5': 'd'}) diff --git a/cqlengine/tests/test_ttl.py b/cqlengine/tests/test_ttl.py index c08a1a19..b55254a5 100644 --- a/cqlengine/tests/test_ttl.py +++ b/cqlengine/tests/test_ttl.py @@ -4,7 +4,8 @@ from cqlengine.models import Model from uuid import uuid4 from cqlengine import columns import mock -from cqlengine.connection import ConnectionPool +from cqlengine.connection import ConnectionPool, get_session + class TestTTLModel(Model): id = columns.UUID(primary_key=True, default=lambda:uuid4()) @@ -39,7 +40,9 @@ class TTLModelTests(BaseTTLTest): def test_ttl_included_on_create(self): """ tests that ttls on models work as expected """ - with mock.patch.object(ConnectionPool, 'execute') as m: + session = get_session() + + with mock.patch.object(session, 'execute') as m: TestTTLModel.ttl(60).create(text="hello blake") query = m.call_args[0][0] @@ -56,8 +59,10 @@ class TTLModelTests(BaseTTLTest): class TTLInstanceUpdateTest(BaseTTLTest): def test_update_includes_ttl(self): + session = get_session() + model = TestTTLModel.create(text="goodbye blake") - with mock.patch.object(ConnectionPool, 'execute') as m: + with mock.patch.object(session, 'execute') as m: model.ttl(60).update(text="goodbye forever") query = m.call_args[0][0] @@ -84,22 +89,27 @@ class TTLInstanceTest(BaseTTLTest): self.assertEqual(60, o._ttl) def test_ttl_is_include_with_query_on_update(self): + session = get_session() + o = TestTTLModel.create(text="whatever") o.text = "new stuff" o = o.ttl(60) - with mock.patch.object(ConnectionPool, 'execute') as m: + with mock.patch.object(session, 'execute') as m: o.save() + query = m.call_args[0][0] self.assertIn("USING TTL", query) class TTLBlindUpdateTest(BaseTTLTest): def test_ttl_included_with_blind_update(self): + session = get_session() + o = TestTTLModel.create(text="whatever") tid = o.id - with mock.patch.object(ConnectionPool, 'execute') as m: + with mock.patch.object(session, 'execute') as m: TestTTLModel.objects(id=tid).ttl(60).update(text="bacon") query = m.call_args[0][0]