diff --git a/.gitignore b/.gitignore
index 6d96b7f0..1e4d8eb3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,6 @@ html/
 #Mr Developer
 .mr.developer.cfg
 .noseids
+/commitlog
+/data
+
diff --git a/cqlengine/connection.py b/cqlengine/connection.py
index a1a18df0..a7b700af 100644
--- a/cqlengine/connection.py
+++ b/cqlengine/connection.py
@@ -9,8 +9,11 @@
 import random
 import cql
 import logging
 
+from copy import copy
 from cqlengine.exceptions import CQLEngineException
+from contextlib import contextmanager
+
 from thrift.transport.TTransport import TTransportException
 
 LOG = logging.getLogger('cqlengine.cql')
@@ -20,7 +23,9 @@ class CQLConnectionError(CQLEngineException): pass
 
 Host = namedtuple('Host', ['name', 'port'])
 _max_connections = 10
-_connection_pool = None
+
+# global connection pool
+connection_pool = None
 
 def setup(hosts, username=None, password=None, max_connections=10, default_keyspace=None):
     """
@@ -29,7 +34,7 @@ def setup(hosts, username=None, password=None, max_connections=10, default_keysp
     :param hosts: list of hosts, strings in the <hostname>:<port>, or just <hostname>
     """
     global _max_connections
-    global _connection_pool
+    global connection_pool
     _max_connections = max_connections
 
     if default_keyspace:
@@ -50,16 +55,13 @@ def setup(hosts, username=None, password=None, max_connections=10, default_keysp
     if not _hosts:
         raise CQLConnectionError("At least one host required")
 
-    _connection_pool = ConnectionPool(_hosts)
+    connection_pool = ConnectionPool(_hosts, username, password)
 
 
 class ConnectionPool(object):
     """Handles pooling of database connections."""
 
-    # Connection pool queue
-    _queue = None
-
-    def __init__(self, hosts, username, password):
+    def __init__(self, hosts, username=None, password=None):
         self._hosts = hosts
         self._username = username
         self._password = password
@@ -113,58 +115,40 @@ class ConnectionPool(object):
         if not self._hosts:
             raise CQLConnectionError("At least one host required")
 
-        host = _hosts[_host_idx]
+        hosts = copy(self._hosts)
+        random.shuffle(hosts)
 
-        new_conn = cql.connect(host.name, host.port, user=_username, password=_password)
-        new_conn.set_cql_version('3.0.0')
-        return new_conn
-
-
-class connection_manager(object):
-    """
-    Connection failure tolerant connection manager.
-    Written to be used in a 'with' block for connection pooling
-    """
-    def __init__(self):
-        if not _hosts:
-            raise CQLConnectionError("No connections have been configured, call cqlengine.connection.setup")
-        self.keyspace = None
-        self.con = ConnectionPool.get()
-        self.cur = None
-
-    def close(self):
-        if self.cur: self.cur.close()
-        ConnectionPool.put(self.con)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, type, value, traceback):
-        self.close()
-
-    def execute(self, query, params={}):
-        """
-        Gets a connection from the pool and executes the given query, returns the cursor
-
-        if there's a connection problem, this will silently create a new connection pool
-        from the available hosts, and remove the problematic host from the host list
-        """
-        global _host_idx
-
-        for i in range(len(_hosts)):
+        for host in hosts:
             try:
-                LOG.debug('{} {}'.format(query, repr(params)))
-                self.cur = self.con.cursor()
-                self.cur.execute(query, params)
-                return self.cur
-            except cql.ProgrammingError as ex:
-                raise CQLEngineException(unicode(ex))
-            except TTransportException:
-                #TODO: check for other errors raised in the event of a connection / server problem
-                #move to the next connection and set the connection pool
-                _host_idx += 1
-                _host_idx %= len(_hosts)
-                self.con.close()
-                self.con = ConnectionPool._create_connection()
+                new_conn = cql.connect(host.name, host.port, user=self._username, password=self._password)
+                new_conn.set_cql_version('3.0.0')
+                return new_conn
+            except Exception as e:
+                logging.debug("Could not establish connection to {}:{}".format(host.name, host.port))
+                pass
 
-        raise CQLConnectionError("couldn't reach a Cassandra server")
+        raise CQLConnectionError("Could not connect to any server in cluster")
 
+    def execute(self, query, params):
+        try:
+            con = self.get()
+            cur = con.cursor()
+            cur.execute(query, params)
+            self.put(con)
+            return cur
+        except cql.ProgrammingError as ex:
+            raise CQLEngineException(unicode(ex))
+        except TTransportException:
+            pass
+
+        raise CQLEngineException("Could not execute query against the cluster")
+
+def execute(query, params={}):
+    return connection_pool.execute(query, params)
+
+@contextmanager
+def connection_manager():
+    global connection_pool
+    tmp = connection_pool.get()
+    yield tmp
+    connection_pool.put(tmp)
diff --git a/cqlengine/management.py b/cqlengine/management.py
index 25717aa4..feddef2f 100644
--- a/cqlengine/management.py
+++ b/cqlengine/management.py
@@ -1,6 +1,6 @@
 import json
 
-from cqlengine.connection import connection_manager
+from cqlengine.connection import connection_manager, execute
 from cqlengine.exceptions import CQLEngineException
 
 def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3, durable_writes=True, **replication_values):
@@ -15,11 +15,11 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3,
     """
     with connection_manager() as con:
         #TODO: check system tables instead of using cql thrifteries
-        if not any([name == k.name for k in con.con.client.describe_keyspaces()]):
-#        if name not in [k.name for k in con.con.client.describe_keyspaces()]:
+        if not any([name == k.name for k in con.client.describe_keyspaces()]):
+        # if name not in [k.name for k in con.con.client.describe_keyspaces()]:
             try:
                 #Try the 1.1 method
-                con.execute("""CREATE KEYSPACE {}
+                execute("""CREATE KEYSPACE {}
                     WITH strategy_class = '{}'
                     AND strategy_options:replication_factor={};""".format(name, strategy_class, replication_factor))
             except CQLEngineException:
@@ -38,12 +38,12 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3,
         if strategy_class != 'SimpleStrategy':
             query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
 
-        con.execute(query)
+        execute(query)
 
 def delete_keyspace(name):
     with connection_manager() as con:
-        if name in [k.name for k in con.con.client.describe_keyspaces()]:
-            con.execute("DROP KEYSPACE {}".format(name))
+        if name in [k.name for k in con.client.describe_keyspaces()]:
+            execute("DROP KEYSPACE {}".format(name))
 
 def create_table(model, create_missing_keyspace=True):
     #construct query string
@@ -55,78 +55,81 @@ def create_table(model, create_missing_keyspace=True):
         create_keyspace(model._get_keyspace())
 
     with connection_manager() as con:
-        #check for an existing column family
-        #TODO: check system tables instead of using cql thrifteries
-        ks_info = con.con.client.describe_keyspace(model._get_keyspace())
-        if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
-            qs = ['CREATE TABLE {}'.format(cf_name)]
+        ks_info = con.client.describe_keyspace(model._get_keyspace())
 
-            #add column types
-            pkeys = []
-            ckeys = []
-            qtypes = []
-            def add_column(col):
-                s = col.get_column_def()
-                if col.primary_key:
-                    keys = (pkeys if col.partition_key else ckeys)
-                    keys.append('"{}"'.format(col.db_field_name))
-                qtypes.append(s)
-            for name, col in model._columns.items():
-                add_column(col)
+    #check for an existing column family
+    #TODO: check system tables instead of using cql thrifteries
+    if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
+        qs = ['CREATE TABLE {}'.format(cf_name)]
 
-            qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or ''))
-
-            qs += ['({})'.format(', '.join(qtypes))]
+        #add column types
+        pkeys = []
+        ckeys = []
+        qtypes = []
+        def add_column(col):
+            s = col.get_column_def()
+            if col.primary_key:
+                keys = (pkeys if col.partition_key else ckeys)
+                keys.append('"{}"'.format(col.db_field_name))
+            qtypes.append(s)
+        for name, col in model._columns.items():
+            add_column(col)
 
-            with_qs = ['read_repair_chance = {}'.format(model.read_repair_chance)]
+        qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or ''))
 
-            _order = ["%s %s" % (c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()]
-            if _order:
-                with_qs.append("clustering order by ({})".format(', '.join(_order)))
+        qs += ['({})'.format(', '.join(qtypes))]
 
-            # add read_repair_chance
-            qs += ['WITH {}'.format(' AND '.join(with_qs))]
+        with_qs = ['read_repair_chance = {}'.format(model.read_repair_chance)]
+
+        _order = ["%s %s" % (c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()]
+        if _order:
+            with_qs.append("clustering order by ({})".format(', '.join(_order)))
+
+        # add read_repair_chance
+        qs += ['WITH {}'.format(' AND '.join(with_qs))]
+        qs = ' '.join(qs)
+
+        try:
+            execute(qs)
+        except CQLEngineException as ex:
+            # 1.2 doesn't return cf names, so we have to examine the exception
+            # and ignore if it says the column family already exists
+            if "Cannot add already existing column family" not in unicode(ex):
+                raise
+
+    #get existing index names, skip ones that already exist
+    with connection_manager() as con:
+        ks_info = con.client.describe_keyspace(model._get_keyspace())
+
+    cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name]
+    idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else []
+    idx_names = filter(None, idx_names)
+
+    indexes = [c for n,c in model._columns.items() if c.index]
+    if indexes:
+        for column in indexes:
+            if column.db_index_name in idx_names: continue
+            qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)]
+            qs += ['ON {}'.format(cf_name)]
+            qs += ['("{}")'.format(column.db_field_name)]
             qs = ' '.join(qs)
 
             try:
-                con.execute(qs)
+                execute(qs)
             except CQLEngineException as ex:
                 # 1.2 doesn't return cf names, so we have to examine the exception
-                # and ignore if it says the column family already exists
-                if "Cannot add already existing column family" not in unicode(ex):
+                # and ignore if it says the index already exists
+                if "Index already exists" not in unicode(ex):
                     raise
 
-        #get existing index names, skip ones that already exist
-        ks_info = con.con.client.describe_keyspace(model._get_keyspace())
-        cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name]
-        idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else []
-        idx_names = filter(None, idx_names)
-
-        indexes = [c for n,c in model._columns.items() if c.index]
-        if indexes:
-            for column in indexes:
-                if column.db_index_name in idx_names: continue
-                qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)]
-                qs += ['ON {}'.format(cf_name)]
-                qs += ['("{}")'.format(column.db_field_name)]
-                qs = ' '.join(qs)
-
-                try:
-                    con.execute(qs)
-                except CQLEngineException as ex:
-                    # 1.2 doesn't return cf names, so we have to examine the exception
-                    # and ignore if it says the index already exists
-                    if "Index already exists" not in unicode(ex):
-                        raise
-
 def delete_table(model):
     cf_name = model.column_family_name()
-    with connection_manager() as con:
-        try:
-            con.execute('drop table {};'.format(cf_name))
-        except CQLEngineException as ex:
-            #don't freak out if the table doesn't exist
-            if 'Cannot drop non existing column family' not in unicode(ex):
-                raise
+
+    try:
+        execute('drop table {};'.format(cf_name))
+    except CQLEngineException as ex:
+        #don't freak out if the table doesn't exist
+        if 'Cannot drop non existing column family' not in unicode(ex):
+            raise
diff --git a/cqlengine/query.py b/cqlengine/query.py
index 3ab71129..d3ae8352 100644
--- a/cqlengine/query.py
+++ b/cqlengine/query.py
@@ -6,7 +6,8 @@ from time import time
 from uuid import uuid1
 
 from cqlengine import BaseContainerColumn, BaseValueManager, Map, columns
-from cqlengine.connection import connection_manager
+from cqlengine.connection import connection_pool, connection_manager, execute
+
 from cqlengine.exceptions import CQLEngineException
 from cqlengine.functions import QueryValue, Token
 
@@ -193,8 +194,7 @@ class BatchQuery(object):
 
         query_list.append('APPLY BATCH;')
 
-        with connection_manager() as con:
-            con.execute('\n'.join(query_list), parameters)
+        execute('\n'.join(query_list), parameters)
 
         self.queries = []
 
@@ -346,8 +346,7 @@ class QuerySet(object):
         if self._batch:
             raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
         if self._result_cache is None:
-            self._con = connection_manager()
-            self._cur = self._con.execute(self._select_query(), self._where_values())
+            self._cur = execute(self._select_query(), self._where_values())
             self._result_cache = [None]*self._cur.rowcount
             if self._cur.description:
                 names = [i[0] for i in self._cur.description]
@@ -368,7 +367,6 @@ class QuerySet(object):
 
         #return the connection to the connection pool if we have all objects
         if self._result_cache and self._result_idx == (len(self._result_cache) - 1):
-            self._con.close()
             self._con = None
             self._cur = None
@@ -555,9 +553,8 @@ class QuerySet(object):
             qs = ' '.join(qs)
 
-            with connection_manager() as con:
-                cur = con.execute(qs, self._where_values())
-                return cur.fetchone()[0]
+            cur = execute(qs, self._where_values())
+            return cur.fetchone()[0]
         else:
             return len(self._result_cache)
 
@@ -635,8 +632,7 @@ class QuerySet(object):
         if self._batch:
             self._batch.add_query(qs, self._where_values())
         else:
-            with connection_manager() as con:
-                con.execute(qs, self._where_values())
+            execute(qs, self._where_values())
 
     def values_list(self, *fields, **kwargs):
         """ Instructs the query set to return tuples, not model instance """
@@ -753,8 +749,7 @@ class DMLQuery(object):
         if self.batch:
             self.batch.add_query(qs, query_values)
         else:
-            with connection_manager() as con:
-                con.execute(qs, query_values)
+            execute(qs, query_values)
 
         # delete nulled columns and removed map keys
 
@@ -787,8 +782,7 @@ class DMLQuery(object):
         if self.batch:
             self.batch.add_query(qs, query_values)
         else:
-            with connection_manager() as con:
-                con.execute(qs, query_values)
+            execute(qs, query_values)
 
     def delete(self):
         """ Deletes one instance """
@@ -809,7 +803,6 @@ class DMLQuery(object):
         if self.batch:
             self.batch.add_query(qs, field_values)
         else:
-            with connection_manager() as con:
-                con.execute(qs, field_values)
+            execute(qs, field_values)
 
diff --git a/cqlengine/tests/base.py b/cqlengine/tests/base.py
index a7090c62..44008811 100644
--- a/cqlengine/tests/base.py
+++ b/cqlengine/tests/base.py
@@ -1,13 +1,13 @@
 from unittest import TestCase
 
-from cqlengine import connection
+from cqlengine import connection
 
 class BaseCassEngTestCase(TestCase):
     @classmethod
     def setUpClass(cls):
         super(BaseCassEngTestCase, cls).setUpClass()
-        if not connection._connection_pool:
-            connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')
+        # todo fix
+        connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')
 
     def assertHasAttr(self, obj, attr):
         self.assertTrue(hasattr(obj, attr),
diff --git a/cqlengine/tests/management/test_management.py b/cqlengine/tests/management/test_management.py
index ac542bb8..be49199f 100644
--- a/cqlengine/tests/management/test_management.py
+++ b/cqlengine/tests/management/test_management.py
@@ -1,44 +1,51 @@
+from cqlengine.exceptions import CQLEngineException
 from cqlengine.management import create_table, delete_table
 from cqlengine.tests.base import BaseCassEngTestCase
-from cqlengine.connection import ConnectionPool
+from cqlengine.connection import ConnectionPool, Host
 
-from mock import Mock
+from mock import Mock, MagicMock, MagicProxy, patch
 from cqlengine import management
 from cqlengine.tests.query.test_queryset import TestModel
+from cql.thrifteries import ThriftConnection
 
-class ConnectionPoolTestCase(BaseCassEngTestCase):
+class ConnectionPoolFailoverTestCase(BaseCassEngTestCase):
     """Test cassandra connection pooling."""
 
     def setUp(self):
-        ConnectionPool.clear()
+        self.host = Host('127.0.0.1', '9160')
+        self.pool = ConnectionPool([self.host])
 
-    def test_should_create_single_connection_on_request(self):
-        """Should create a single connection on first request"""
-        result = ConnectionPool.get()
-        self.assertIsNotNone(result)
-        self.assertEquals(0, ConnectionPool._queue.qsize())
-        ConnectionPool._queue.put(result)
-        self.assertEquals(1, ConnectionPool._queue.qsize())
+    def test_totally_dead_pool(self):
+        # kill the con
+        with patch('cqlengine.connection.cql.connect') as mock:
+            mock.side_effect=CQLEngineException
+            with self.assertRaises(CQLEngineException):
+                self.pool.execute("select * from system.peers", {})
-    def test_should_close_connection_if_queue_is_full(self):
-        """Should close additional connections if queue is full"""
-        connections = [ConnectionPool.get() for x in range(10)]
-        for conn in connections:
-            ConnectionPool.put(conn)
-        fake_conn = Mock()
-        ConnectionPool.put(fake_conn)
-        fake_conn.close.assert_called_once_with()
+    def test_dead_node(self):
+        self.pool._hosts.append(self.host)
+
+        # cursor mock needed so set_cql_version doesn't crap out
+        ok_cur = MagicMock()
+
+        ok_conn = MagicMock()
+        ok_conn.return_value = ok_cur
+
+
+        returns = [CQLEngineException(), ok_conn]
+
+        def side_effect(*args, **kwargs):
+            result = returns.pop(0)
+            if isinstance(result, Exception):
+                raise result
+            return result
+
+        with patch('cqlengine.connection.cql.connect') as mock:
+            mock.side_effect = side_effect
+            conn = self.pool._create_connection()
 
-    def test_should_pop_connections_from_queue(self):
-        """Should pull existing connections off of the queue"""
-        conn = ConnectionPool.get()
-        ConnectionPool.put(conn)
-        self.assertEquals(1, ConnectionPool._queue.qsize())
-        self.assertEquals(conn, ConnectionPool.get())
-        self.assertEquals(0, ConnectionPool._queue.qsize())
-
 class CreateKeyspaceTest(BaseCassEngTestCase):
     def test_create_succeeeds(self):
diff --git a/cqlengine/tests/query/test_queryset.py b/cqlengine/tests/query/test_queryset.py
index 1c3eb95b..17954d18 100644
--- a/cqlengine/tests/query/test_queryset.py
+++ b/cqlengine/tests/query/test_queryset.py
@@ -405,18 +405,7 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage):
         assert q._con is None
         assert q._cur is None
 
-    def test_conn_is_returned_after_queryset_is_garbage_collected(self):
-        """ Tests that the connection is returned to the connection pool after the queryset is gc'd """
-        from cqlengine.connection import ConnectionPool
-        # The queue size can be 1 if we just run this file's tests
-        # It will be 2 when we run 'em all
-        initial_size = ConnectionPool._queue.qsize()
-        q = TestModel.objects(test_id=0)
-        v = q[0]
-        assert ConnectionPool._queue.qsize() == initial_size - 1
-        del q
-        assert ConnectionPool._queue.qsize() == initial_size
 
 
 class TimeUUIDQueryModel(Model):
     partition = columns.UUID(primary_key=True)
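
Note for reviewers: below is a minimal usage sketch of the reworked connection API this patch introduces (module-level execute() plus the generator-based connection_manager()). It is not part of the patch; the host, keyspace, and query are illustrative placeholders taken from the test setup above, and queries now check a connection out of the global pool per call instead of the caller holding a connection_manager object.

    # Sketch only: assumes a reachable Cassandra node on localhost:9160 and an
    # illustrative keyspace name; all functions referenced exist in this patch.
    from cqlengine import connection
    from cqlengine.connection import execute, connection_manager

    # Builds the global ConnectionPool; on connect failure the pool shuffles
    # its host list and tries the next host before giving up.
    connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')

    # Module-level execute() gets a connection from the pool, runs the query,
    # returns the connection to the pool, and hands back the cursor.
    cur = execute("SELECT * FROM system.peers", {})
    rows = cur.fetchall()

    # connection_manager() is now a contextmanager yielding a raw pooled
    # connection (e.g. for thrift calls) and returning it on block exit.
    with connection_manager() as con:
        keyspaces = [k.name for k in con.client.describe_keyspaces()]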