Getting connection pooling working with mocking library
Verified throwing exception when no servers are available, but correctly recovering and hitting the next server when one is. fixing minor pooling tests ensure we get the right exception back when no servers are available working on tests for retry connections working despite failure Removed old connection_manager and replaced with a simple context manager that allows for easy access to clients within the main pool
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -39,3 +39,6 @@ html/
|
||||
#Mr Developer
|
||||
.mr.developer.cfg
|
||||
.noseids
|
||||
/commitlog
|
||||
/data
|
||||
|
||||
|
@@ -9,8 +9,11 @@ import random
|
||||
import cql
|
||||
import logging
|
||||
|
||||
from copy import copy
|
||||
from cqlengine.exceptions import CQLEngineException
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from thrift.transport.TTransport import TTransportException
|
||||
|
||||
LOG = logging.getLogger('cqlengine.cql')
|
||||
@@ -20,7 +23,9 @@ class CQLConnectionError(CQLEngineException): pass
|
||||
Host = namedtuple('Host', ['name', 'port'])
|
||||
|
||||
_max_connections = 10
|
||||
_connection_pool = None
|
||||
|
||||
# global connection pool
|
||||
connection_pool = None
|
||||
|
||||
def setup(hosts, username=None, password=None, max_connections=10, default_keyspace=None):
|
||||
"""
|
||||
@@ -29,7 +34,7 @@ def setup(hosts, username=None, password=None, max_connections=10, default_keysp
|
||||
:param hosts: list of hosts, strings in the <hostname>:<port>, or just <hostname>
|
||||
"""
|
||||
global _max_connections
|
||||
global _connection_pool
|
||||
global connection_pool
|
||||
_max_connections = max_connections
|
||||
|
||||
if default_keyspace:
|
||||
@@ -50,16 +55,13 @@ def setup(hosts, username=None, password=None, max_connections=10, default_keysp
|
||||
if not _hosts:
|
||||
raise CQLConnectionError("At least one host required")
|
||||
|
||||
_connection_pool = ConnectionPool(_hosts)
|
||||
connection_pool = ConnectionPool(_hosts, username, password)
|
||||
|
||||
|
||||
class ConnectionPool(object):
|
||||
"""Handles pooling of database connections."""
|
||||
|
||||
# Connection pool queue
|
||||
_queue = None
|
||||
|
||||
def __init__(self, hosts, username, password):
|
||||
def __init__(self, hosts, username=None, password=None):
|
||||
self._hosts = hosts
|
||||
self._username = username
|
||||
self._password = password
|
||||
@@ -113,58 +115,40 @@ class ConnectionPool(object):
|
||||
if not self._hosts:
|
||||
raise CQLConnectionError("At least one host required")
|
||||
|
||||
host = _hosts[_host_idx]
|
||||
hosts = copy(self._hosts)
|
||||
random.shuffle(hosts)
|
||||
|
||||
new_conn = cql.connect(host.name, host.port, user=_username, password=_password)
|
||||
new_conn.set_cql_version('3.0.0')
|
||||
return new_conn
|
||||
|
||||
|
||||
class connection_manager(object):
|
||||
"""
|
||||
Connection failure tolerant connection manager. Written to be used in a 'with' block for connection pooling
|
||||
"""
|
||||
def __init__(self):
|
||||
if not _hosts:
|
||||
raise CQLConnectionError("No connections have been configured, call cqlengine.connection.setup")
|
||||
self.keyspace = None
|
||||
self.con = ConnectionPool.get()
|
||||
self.cur = None
|
||||
|
||||
def close(self):
|
||||
if self.cur: self.cur.close()
|
||||
ConnectionPool.put(self.con)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.close()
|
||||
|
||||
def execute(self, query, params={}):
|
||||
"""
|
||||
Gets a connection from the pool and executes the given query, returns the cursor
|
||||
|
||||
if there's a connection problem, this will silently create a new connection pool
|
||||
from the available hosts, and remove the problematic host from the host list
|
||||
"""
|
||||
global _host_idx
|
||||
|
||||
for i in range(len(_hosts)):
|
||||
for host in hosts:
|
||||
try:
|
||||
LOG.debug('{} {}'.format(query, repr(params)))
|
||||
self.cur = self.con.cursor()
|
||||
self.cur.execute(query, params)
|
||||
return self.cur
|
||||
except cql.ProgrammingError as ex:
|
||||
raise CQLEngineException(unicode(ex))
|
||||
except TTransportException:
|
||||
#TODO: check for other errors raised in the event of a connection / server problem
|
||||
#move to the next connection and set the connection pool
|
||||
_host_idx += 1
|
||||
_host_idx %= len(_hosts)
|
||||
self.con.close()
|
||||
self.con = ConnectionPool._create_connection()
|
||||
new_conn = cql.connect(host.name, host.port, user=self._username, password=self._password)
|
||||
new_conn.set_cql_version('3.0.0')
|
||||
return new_conn
|
||||
except Exception as e:
|
||||
logging.debug("Could not establish connection to {}:{}".format(host.name, host.port))
|
||||
pass
|
||||
|
||||
raise CQLConnectionError("couldn't reach a Cassandra server")
|
||||
raise CQLConnectionError("Could not connect to any server in cluster")
|
||||
|
||||
def execute(self, query, params):
|
||||
try:
|
||||
con = self.get()
|
||||
cur = con.cursor()
|
||||
cur.execute(query, params)
|
||||
self.put(con)
|
||||
return cur
|
||||
except cql.ProgrammingError as ex:
|
||||
raise CQLEngineException(unicode(ex))
|
||||
except TTransportException:
|
||||
pass
|
||||
|
||||
raise CQLEngineException("Could not execute query against the cluster")
|
||||
|
||||
def execute(query, params={}):
|
||||
return connection_pool.execute(query, params)
|
||||
|
||||
@contextmanager
|
||||
def connection_manager():
|
||||
global connection_pool
|
||||
tmp = connection_pool.get()
|
||||
yield tmp
|
||||
connection_pool.put(tmp)
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
|
||||
from cqlengine.connection import connection_manager
|
||||
from cqlengine.connection import connection_manager, execute
|
||||
from cqlengine.exceptions import CQLEngineException
|
||||
|
||||
def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3, durable_writes=True, **replication_values):
|
||||
@@ -15,11 +15,11 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3,
|
||||
"""
|
||||
with connection_manager() as con:
|
||||
#TODO: check system tables instead of using cql thrifteries
|
||||
if not any([name == k.name for k in con.con.client.describe_keyspaces()]):
|
||||
# if name not in [k.name for k in con.con.client.describe_keyspaces()]:
|
||||
if not any([name == k.name for k in con.client.describe_keyspaces()]):
|
||||
# if name not in [k.name for k in con.con.client.describe_keyspaces()]:
|
||||
try:
|
||||
#Try the 1.1 method
|
||||
con.execute("""CREATE KEYSPACE {}
|
||||
execute("""CREATE KEYSPACE {}
|
||||
WITH strategy_class = '{}'
|
||||
AND strategy_options:replication_factor={};""".format(name, strategy_class, replication_factor))
|
||||
except CQLEngineException:
|
||||
@@ -38,12 +38,12 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3,
|
||||
if strategy_class != 'SimpleStrategy':
|
||||
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
|
||||
|
||||
con.execute(query)
|
||||
execute(query)
|
||||
|
||||
def delete_keyspace(name):
|
||||
with connection_manager() as con:
|
||||
if name in [k.name for k in con.con.client.describe_keyspaces()]:
|
||||
con.execute("DROP KEYSPACE {}".format(name))
|
||||
if name in [k.name for k in con.client.describe_keyspaces()]:
|
||||
execute("DROP KEYSPACE {}".format(name))
|
||||
|
||||
def create_table(model, create_missing_keyspace=True):
|
||||
#construct query string
|
||||
@@ -55,78 +55,81 @@ def create_table(model, create_missing_keyspace=True):
|
||||
create_keyspace(model._get_keyspace())
|
||||
|
||||
with connection_manager() as con:
|
||||
#check for an existing column family
|
||||
#TODO: check system tables instead of using cql thrifteries
|
||||
ks_info = con.con.client.describe_keyspace(model._get_keyspace())
|
||||
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
|
||||
qs = ['CREATE TABLE {}'.format(cf_name)]
|
||||
ks_info = con.client.describe_keyspace(model._get_keyspace())
|
||||
|
||||
#add column types
|
||||
pkeys = []
|
||||
ckeys = []
|
||||
qtypes = []
|
||||
def add_column(col):
|
||||
s = col.get_column_def()
|
||||
if col.primary_key:
|
||||
keys = (pkeys if col.partition_key else ckeys)
|
||||
keys.append('"{}"'.format(col.db_field_name))
|
||||
qtypes.append(s)
|
||||
for name, col in model._columns.items():
|
||||
add_column(col)
|
||||
#check for an existing column family
|
||||
#TODO: check system tables instead of using cql thrifteries
|
||||
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
|
||||
qs = ['CREATE TABLE {}'.format(cf_name)]
|
||||
|
||||
qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or ''))
|
||||
|
||||
qs += ['({})'.format(', '.join(qtypes))]
|
||||
#add column types
|
||||
pkeys = []
|
||||
ckeys = []
|
||||
qtypes = []
|
||||
def add_column(col):
|
||||
s = col.get_column_def()
|
||||
if col.primary_key:
|
||||
keys = (pkeys if col.partition_key else ckeys)
|
||||
keys.append('"{}"'.format(col.db_field_name))
|
||||
qtypes.append(s)
|
||||
for name, col in model._columns.items():
|
||||
add_column(col)
|
||||
|
||||
with_qs = ['read_repair_chance = {}'.format(model.read_repair_chance)]
|
||||
qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or ''))
|
||||
|
||||
_order = ["%s %s" % (c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()]
|
||||
if _order:
|
||||
with_qs.append("clustering order by ({})".format(', '.join(_order)))
|
||||
qs += ['({})'.format(', '.join(qtypes))]
|
||||
|
||||
# add read_repair_chance
|
||||
qs += ['WITH {}'.format(' AND '.join(with_qs))]
|
||||
with_qs = ['read_repair_chance = {}'.format(model.read_repair_chance)]
|
||||
|
||||
_order = ["%s %s" % (c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()]
|
||||
if _order:
|
||||
with_qs.append("clustering order by ({})".format(', '.join(_order)))
|
||||
|
||||
# add read_repair_chance
|
||||
qs += ['WITH {}'.format(' AND '.join(with_qs))]
|
||||
qs = ' '.join(qs)
|
||||
|
||||
try:
|
||||
execute(qs)
|
||||
except CQLEngineException as ex:
|
||||
# 1.2 doesn't return cf names, so we have to examine the exception
|
||||
# and ignore if it says the column family already exists
|
||||
if "Cannot add already existing column family" not in unicode(ex):
|
||||
raise
|
||||
|
||||
#get existing index names, skip ones that already exist
|
||||
with connection_manager() as con:
|
||||
ks_info = con.client.describe_keyspace(model._get_keyspace())
|
||||
|
||||
cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name]
|
||||
idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else []
|
||||
idx_names = filter(None, idx_names)
|
||||
|
||||
indexes = [c for n,c in model._columns.items() if c.index]
|
||||
if indexes:
|
||||
for column in indexes:
|
||||
if column.db_index_name in idx_names: continue
|
||||
qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)]
|
||||
qs += ['ON {}'.format(cf_name)]
|
||||
qs += ['("{}")'.format(column.db_field_name)]
|
||||
qs = ' '.join(qs)
|
||||
|
||||
try:
|
||||
con.execute(qs)
|
||||
execute(qs)
|
||||
except CQLEngineException as ex:
|
||||
# 1.2 doesn't return cf names, so we have to examine the exception
|
||||
# and ignore if it says the column family already exists
|
||||
if "Cannot add already existing column family" not in unicode(ex):
|
||||
# and ignore if it says the index already exists
|
||||
if "Index already exists" not in unicode(ex):
|
||||
raise
|
||||
|
||||
#get existing index names, skip ones that already exist
|
||||
ks_info = con.con.client.describe_keyspace(model._get_keyspace())
|
||||
cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name]
|
||||
idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else []
|
||||
idx_names = filter(None, idx_names)
|
||||
|
||||
indexes = [c for n,c in model._columns.items() if c.index]
|
||||
if indexes:
|
||||
for column in indexes:
|
||||
if column.db_index_name in idx_names: continue
|
||||
qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)]
|
||||
qs += ['ON {}'.format(cf_name)]
|
||||
qs += ['("{}")'.format(column.db_field_name)]
|
||||
qs = ' '.join(qs)
|
||||
|
||||
try:
|
||||
con.execute(qs)
|
||||
except CQLEngineException as ex:
|
||||
# 1.2 doesn't return cf names, so we have to examine the exception
|
||||
# and ignore if it says the index already exists
|
||||
if "Index already exists" not in unicode(ex):
|
||||
raise
|
||||
|
||||
|
||||
def delete_table(model):
|
||||
cf_name = model.column_family_name()
|
||||
with connection_manager() as con:
|
||||
try:
|
||||
con.execute('drop table {};'.format(cf_name))
|
||||
except CQLEngineException as ex:
|
||||
#don't freak out if the table doesn't exist
|
||||
if 'Cannot drop non existing column family' not in unicode(ex):
|
||||
raise
|
||||
|
||||
try:
|
||||
execute('drop table {};'.format(cf_name))
|
||||
except CQLEngineException as ex:
|
||||
#don't freak out if the table doesn't exist
|
||||
if 'Cannot drop non existing column family' not in unicode(ex):
|
||||
raise
|
||||
|
||||
|
@@ -6,7 +6,8 @@ from time import time
|
||||
from uuid import uuid1
|
||||
from cqlengine import BaseContainerColumn, BaseValueManager, Map, columns
|
||||
|
||||
from cqlengine.connection import connection_manager
|
||||
from cqlengine.connection import connection_pool, connection_manager, execute
|
||||
|
||||
from cqlengine.exceptions import CQLEngineException
|
||||
from cqlengine.functions import QueryValue, Token
|
||||
|
||||
@@ -193,8 +194,7 @@ class BatchQuery(object):
|
||||
|
||||
query_list.append('APPLY BATCH;')
|
||||
|
||||
with connection_manager() as con:
|
||||
con.execute('\n'.join(query_list), parameters)
|
||||
execute('\n'.join(query_list), parameters)
|
||||
|
||||
self.queries = []
|
||||
|
||||
@@ -346,8 +346,7 @@ class QuerySet(object):
|
||||
if self._batch:
|
||||
raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
|
||||
if self._result_cache is None:
|
||||
self._con = connection_manager()
|
||||
self._cur = self._con.execute(self._select_query(), self._where_values())
|
||||
self._cur = execute(self._select_query(), self._where_values())
|
||||
self._result_cache = [None]*self._cur.rowcount
|
||||
if self._cur.description:
|
||||
names = [i[0] for i in self._cur.description]
|
||||
@@ -368,7 +367,6 @@ class QuerySet(object):
|
||||
|
||||
#return the connection to the connection pool if we have all objects
|
||||
if self._result_cache and self._result_idx == (len(self._result_cache) - 1):
|
||||
self._con.close()
|
||||
self._con = None
|
||||
self._cur = None
|
||||
|
||||
@@ -555,9 +553,8 @@ class QuerySet(object):
|
||||
|
||||
qs = ' '.join(qs)
|
||||
|
||||
with connection_manager() as con:
|
||||
cur = con.execute(qs, self._where_values())
|
||||
return cur.fetchone()[0]
|
||||
cur = execute(qs, self._where_values())
|
||||
return cur.fetchone()[0]
|
||||
else:
|
||||
return len(self._result_cache)
|
||||
|
||||
@@ -635,8 +632,7 @@ class QuerySet(object):
|
||||
if self._batch:
|
||||
self._batch.add_query(qs, self._where_values())
|
||||
else:
|
||||
with connection_manager() as con:
|
||||
con.execute(qs, self._where_values())
|
||||
execute(qs, self._where_values())
|
||||
|
||||
def values_list(self, *fields, **kwargs):
|
||||
""" Instructs the query set to return tuples, not model instance """
|
||||
@@ -753,8 +749,7 @@ class DMLQuery(object):
|
||||
if self.batch:
|
||||
self.batch.add_query(qs, query_values)
|
||||
else:
|
||||
with connection_manager() as con:
|
||||
con.execute(qs, query_values)
|
||||
execute(qs, query_values)
|
||||
|
||||
|
||||
# delete nulled columns and removed map keys
|
||||
@@ -787,8 +782,7 @@ class DMLQuery(object):
|
||||
if self.batch:
|
||||
self.batch.add_query(qs, query_values)
|
||||
else:
|
||||
with connection_manager() as con:
|
||||
con.execute(qs, query_values)
|
||||
execute(qs, query_values)
|
||||
|
||||
def delete(self):
|
||||
""" Deletes one instance """
|
||||
@@ -809,7 +803,6 @@ class DMLQuery(object):
|
||||
if self.batch:
|
||||
self.batch.add_query(qs, field_values)
|
||||
else:
|
||||
with connection_manager() as con:
|
||||
con.execute(qs, field_values)
|
||||
execute(qs, field_values)
|
||||
|
||||
|
||||
|
@@ -1,13 +1,13 @@
|
||||
from unittest import TestCase
|
||||
from cqlengine import connection
|
||||
from cqlengine import connection
|
||||
|
||||
class BaseCassEngTestCase(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(BaseCassEngTestCase, cls).setUpClass()
|
||||
if not connection._connection_pool:
|
||||
connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')
|
||||
# todo fix
|
||||
connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')
|
||||
|
||||
def assertHasAttr(self, obj, attr):
|
||||
self.assertTrue(hasattr(obj, attr),
|
||||
|
@@ -1,44 +1,51 @@
|
||||
from cqlengine.exceptions import CQLEngineException
|
||||
from cqlengine.management import create_table, delete_table
|
||||
from cqlengine.tests.base import BaseCassEngTestCase
|
||||
|
||||
from cqlengine.connection import ConnectionPool
|
||||
from cqlengine.connection import ConnectionPool, Host
|
||||
|
||||
from mock import Mock
|
||||
from mock import Mock, MagicMock, MagicProxy, patch
|
||||
from cqlengine import management
|
||||
from cqlengine.tests.query.test_queryset import TestModel
|
||||
|
||||
from cql.thrifteries import ThriftConnection
|
||||
|
||||
class ConnectionPoolTestCase(BaseCassEngTestCase):
|
||||
class ConnectionPoolFailoverTestCase(BaseCassEngTestCase):
|
||||
"""Test cassandra connection pooling."""
|
||||
|
||||
def setUp(self):
|
||||
ConnectionPool.clear()
|
||||
self.host = Host('127.0.0.1', '9160')
|
||||
self.pool = ConnectionPool([self.host])
|
||||
|
||||
def test_should_create_single_connection_on_request(self):
|
||||
"""Should create a single connection on first request"""
|
||||
result = ConnectionPool.get()
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEquals(0, ConnectionPool._queue.qsize())
|
||||
ConnectionPool._queue.put(result)
|
||||
self.assertEquals(1, ConnectionPool._queue.qsize())
|
||||
def test_totally_dead_pool(self):
|
||||
# kill the con
|
||||
with patch('cqlengine.connection.cql.connect') as mock:
|
||||
mock.side_effect=CQLEngineException
|
||||
with self.assertRaises(CQLEngineException):
|
||||
self.pool.execute("select * from system.peers", {})
|
||||
|
||||
def test_should_close_connection_if_queue_is_full(self):
|
||||
"""Should close additional connections if queue is full"""
|
||||
connections = [ConnectionPool.get() for x in range(10)]
|
||||
for conn in connections:
|
||||
ConnectionPool.put(conn)
|
||||
fake_conn = Mock()
|
||||
ConnectionPool.put(fake_conn)
|
||||
fake_conn.close.assert_called_once_with()
|
||||
def test_dead_node(self):
|
||||
self.pool._hosts.append(self.host)
|
||||
|
||||
# cursor mock needed so set_cql_version doesn't crap out
|
||||
ok_cur = MagicMock()
|
||||
|
||||
ok_conn = MagicMock()
|
||||
ok_conn.return_value = ok_cur
|
||||
|
||||
|
||||
returns = [CQLEngineException(), ok_conn]
|
||||
|
||||
def side_effect(*args, **kwargs):
|
||||
result = returns.pop(0)
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
||||
with patch('cqlengine.connection.cql.connect') as mock:
|
||||
mock.side_effect = side_effect
|
||||
conn = self.pool._create_connection()
|
||||
|
||||
def test_should_pop_connections_from_queue(self):
|
||||
"""Should pull existing connections off of the queue"""
|
||||
conn = ConnectionPool.get()
|
||||
ConnectionPool.put(conn)
|
||||
self.assertEquals(1, ConnectionPool._queue.qsize())
|
||||
self.assertEquals(conn, ConnectionPool.get())
|
||||
self.assertEquals(0, ConnectionPool._queue.qsize())
|
||||
|
||||
|
||||
class CreateKeyspaceTest(BaseCassEngTestCase):
|
||||
def test_create_succeeeds(self):
|
||||
|
@@ -405,18 +405,7 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage):
|
||||
assert q._con is None
|
||||
assert q._cur is None
|
||||
|
||||
def test_conn_is_returned_after_queryset_is_garbage_collected(self):
|
||||
""" Tests that the connection is returned to the connection pool after the queryset is gc'd """
|
||||
from cqlengine.connection import ConnectionPool
|
||||
# The queue size can be 1 if we just run this file's tests
|
||||
# It will be 2 when we run 'em all
|
||||
initial_size = ConnectionPool._queue.qsize()
|
||||
q = TestModel.objects(test_id=0)
|
||||
v = q[0]
|
||||
assert ConnectionPool._queue.qsize() == initial_size - 1
|
||||
|
||||
del q
|
||||
assert ConnectionPool._queue.qsize() == initial_size
|
||||
|
||||
class TimeUUIDQueryModel(Model):
|
||||
partition = columns.UUID(primary_key=True)
|
||||
|
Reference in New Issue
Block a user