Getting connection pooling working with the mocking library

Verified that an exception is thrown when no servers are available,
and that the pool correctly recovers and hits the next server when
one is.

fixing minor pooling tests

ensure we get the right exception back when no servers are available

working on tests for retry

Connections keep working despite individual host failures

Removed the old connection_manager class and replaced it with a simple
context manager that allows easy access to clients within the main pool.
Jon Haddad
2013-06-03 14:48:44 -07:00
parent e6a7934fe3
commit 37259b963e
7 changed files with 161 additions and 182 deletions
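
Before the per-file diff, a minimal sketch of how the reworked connection API reads from the caller's side, pieced together from the hunks below; the host, keyspace, and query are illustrative placeholders, not part of this commit:

from cqlengine import connection
from cqlengine.connection import connection_manager, execute

# One-time pool setup; host and keyspace names are placeholders.
connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')

# The module-level execute() delegates to the global connection_pool,
# which checks out a connection, runs the query, and puts it back.
cur = execute("SELECT * FROM system.peers", {})

# connection_manager() is now a plain @contextmanager that checks a raw
# connection out of the pool and returns it when the block exits.
with connection_manager() as con:
    keyspaces = [k.name for k in con.client.describe_keyspaces()]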

.gitignore
View File

@@ -39,3 +39,6 @@ html/
#Mr Developer
.mr.developer.cfg
.noseids
/commitlog
/data

View File

@@ -9,8 +9,11 @@ import random
import cql
import logging
from copy import copy
from cqlengine.exceptions import CQLEngineException
from contextlib import contextmanager
from thrift.transport.TTransport import TTransportException
LOG = logging.getLogger('cqlengine.cql')
@@ -20,7 +23,9 @@ class CQLConnectionError(CQLEngineException): pass
Host = namedtuple('Host', ['name', 'port'])
_max_connections = 10
_connection_pool = None
# global connection pool
connection_pool = None
def setup(hosts, username=None, password=None, max_connections=10, default_keyspace=None):
"""
@@ -29,7 +34,7 @@ def setup(hosts, username=None, password=None, max_connections=10, default_keysp
:param hosts: list of hosts, strings in the <hostname>:<port>, or just <hostname>
"""
global _max_connections
global _connection_pool
global connection_pool
_max_connections = max_connections
if default_keyspace:
@@ -50,16 +55,13 @@ def setup(hosts, username=None, password=None, max_connections=10, default_keysp
if not _hosts:
raise CQLConnectionError("At least one host required")
_connection_pool = ConnectionPool(_hosts)
connection_pool = ConnectionPool(_hosts, username, password)
class ConnectionPool(object):
"""Handles pooling of database connections."""
# Connection pool queue
_queue = None
def __init__(self, hosts, username, password):
def __init__(self, hosts, username=None, password=None):
self._hosts = hosts
self._username = username
self._password = password
@@ -113,58 +115,40 @@ class ConnectionPool(object):
if not self._hosts:
raise CQLConnectionError("At least one host required")
host = _hosts[_host_idx]
hosts = copy(self._hosts)
random.shuffle(hosts)
new_conn = cql.connect(host.name, host.port, user=_username, password=_password)
new_conn.set_cql_version('3.0.0')
return new_conn
class connection_manager(object):
"""
Connection failure tolerant connection manager. Written to be used in a 'with' block for connection pooling
"""
def __init__(self):
if not _hosts:
raise CQLConnectionError("No connections have been configured, call cqlengine.connection.setup")
self.keyspace = None
self.con = ConnectionPool.get()
self.cur = None
def close(self):
if self.cur: self.cur.close()
ConnectionPool.put(self.con)
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
def execute(self, query, params={}):
"""
Gets a connection from the pool and executes the given query, returns the cursor
if there's a connection problem, this will silently create a new connection pool
from the available hosts, and remove the problematic host from the host list
"""
global _host_idx
for i in range(len(_hosts)):
for host in hosts:
try:
LOG.debug('{} {}'.format(query, repr(params)))
self.cur = self.con.cursor()
self.cur.execute(query, params)
return self.cur
except cql.ProgrammingError as ex:
raise CQLEngineException(unicode(ex))
except TTransportException:
#TODO: check for other errors raised in the event of a connection / server problem
#move to the next connection and set the connection pool
_host_idx += 1
_host_idx %= len(_hosts)
self.con.close()
self.con = ConnectionPool._create_connection()
new_conn = cql.connect(host.name, host.port, user=self._username, password=self._password)
new_conn.set_cql_version('3.0.0')
return new_conn
except Exception as e:
logging.debug("Could not establish connection to {}:{}".format(host.name, host.port))
pass
raise CQLConnectionError("couldn't reach a Cassandra server")
raise CQLConnectionError("Could not connect to any server in cluster")
def execute(self, query, params):
try:
con = self.get()
cur = con.cursor()
cur.execute(query, params)
self.put(con)
return cur
except cql.ProgrammingError as ex:
raise CQLEngineException(unicode(ex))
except TTransportException:
pass
raise CQLEngineException("Could not execute query against the cluster")
def execute(query, params={}):
return connection_pool.execute(query, params)
@contextmanager
def connection_manager():
global connection_pool
tmp = connection_pool.get()
yield tmp
connection_pool.put(tmp)
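
Since the hunk above interleaves removed and added lines, here is a condensed, standalone sketch of the failover behaviour the pool's _create_connection now implements; this is a simplification written for illustration, not the exact committed code:

import random
from copy import copy

import cql
from cqlengine.connection import CQLConnectionError

def create_connection(hosts, username=None, password=None):
    """Try each host in random order and return the first live connection."""
    candidates = copy(hosts)
    random.shuffle(candidates)
    for host in candidates:
        try:
            conn = cql.connect(host.name, host.port, user=username, password=password)
            conn.set_cql_version('3.0.0')
            return conn
        except Exception:
            # Dead or unreachable node: fall through and try the next host.
            continue
    raise CQLConnectionError("Could not connect to any server in cluster")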

View File

@@ -1,6 +1,6 @@
import json
from cqlengine.connection import connection_manager
from cqlengine.connection import connection_manager, execute
from cqlengine.exceptions import CQLEngineException
def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3, durable_writes=True, **replication_values):
@@ -15,11 +15,11 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3,
"""
with connection_manager() as con:
#TODO: check system tables instead of using cql thrifteries
if not any([name == k.name for k in con.con.client.describe_keyspaces()]):
# if name not in [k.name for k in con.con.client.describe_keyspaces()]:
if not any([name == k.name for k in con.client.describe_keyspaces()]):
# if name not in [k.name for k in con.con.client.describe_keyspaces()]:
try:
#Try the 1.1 method
con.execute("""CREATE KEYSPACE {}
execute("""CREATE KEYSPACE {}
WITH strategy_class = '{}'
AND strategy_options:replication_factor={};""".format(name, strategy_class, replication_factor))
except CQLEngineException:
@@ -38,12 +38,12 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3,
if strategy_class != 'SimpleStrategy':
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
con.execute(query)
execute(query)
def delete_keyspace(name):
with connection_manager() as con:
if name in [k.name for k in con.con.client.describe_keyspaces()]:
con.execute("DROP KEYSPACE {}".format(name))
if name in [k.name for k in con.client.describe_keyspaces()]:
execute("DROP KEYSPACE {}".format(name))
def create_table(model, create_missing_keyspace=True):
#construct query string
@@ -55,78 +55,81 @@ def create_table(model, create_missing_keyspace=True):
create_keyspace(model._get_keyspace())
with connection_manager() as con:
#check for an existing column family
#TODO: check system tables instead of using cql thrifteries
ks_info = con.con.client.describe_keyspace(model._get_keyspace())
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
qs = ['CREATE TABLE {}'.format(cf_name)]
ks_info = con.client.describe_keyspace(model._get_keyspace())
#add column types
pkeys = []
ckeys = []
qtypes = []
def add_column(col):
s = col.get_column_def()
if col.primary_key:
keys = (pkeys if col.partition_key else ckeys)
keys.append('"{}"'.format(col.db_field_name))
qtypes.append(s)
for name, col in model._columns.items():
add_column(col)
#check for an existing column family
#TODO: check system tables instead of using cql thrifteries
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
qs = ['CREATE TABLE {}'.format(cf_name)]
qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or ''))
qs += ['({})'.format(', '.join(qtypes))]
#add column types
pkeys = []
ckeys = []
qtypes = []
def add_column(col):
s = col.get_column_def()
if col.primary_key:
keys = (pkeys if col.partition_key else ckeys)
keys.append('"{}"'.format(col.db_field_name))
qtypes.append(s)
for name, col in model._columns.items():
add_column(col)
with_qs = ['read_repair_chance = {}'.format(model.read_repair_chance)]
qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or ''))
_order = ["%s %s" % (c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()]
if _order:
with_qs.append("clustering order by ({})".format(', '.join(_order)))
qs += ['({})'.format(', '.join(qtypes))]
# add read_repair_chance
qs += ['WITH {}'.format(' AND '.join(with_qs))]
with_qs = ['read_repair_chance = {}'.format(model.read_repair_chance)]
_order = ["%s %s" % (c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()]
if _order:
with_qs.append("clustering order by ({})".format(', '.join(_order)))
# add read_repair_chance
qs += ['WITH {}'.format(' AND '.join(with_qs))]
qs = ' '.join(qs)
try:
execute(qs)
except CQLEngineException as ex:
# 1.2 doesn't return cf names, so we have to examine the exception
# and ignore if it says the column family already exists
if "Cannot add already existing column family" not in unicode(ex):
raise
#get existing index names, skip ones that already exist
with connection_manager() as con:
ks_info = con.client.describe_keyspace(model._get_keyspace())
cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name]
idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else []
idx_names = filter(None, idx_names)
indexes = [c for n,c in model._columns.items() if c.index]
if indexes:
for column in indexes:
if column.db_index_name in idx_names: continue
qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)]
qs += ['ON {}'.format(cf_name)]
qs += ['("{}")'.format(column.db_field_name)]
qs = ' '.join(qs)
try:
con.execute(qs)
execute(qs)
except CQLEngineException as ex:
# 1.2 doesn't return cf names, so we have to examine the exception
# and ignore if it says the column family already exists
if "Cannot add already existing column family" not in unicode(ex):
# and ignore if it says the index already exists
if "Index already exists" not in unicode(ex):
raise
#get existing index names, skip ones that already exist
ks_info = con.con.client.describe_keyspace(model._get_keyspace())
cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name]
idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else []
idx_names = filter(None, idx_names)
indexes = [c for n,c in model._columns.items() if c.index]
if indexes:
for column in indexes:
if column.db_index_name in idx_names: continue
qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)]
qs += ['ON {}'.format(cf_name)]
qs += ['("{}")'.format(column.db_field_name)]
qs = ' '.join(qs)
try:
con.execute(qs)
except CQLEngineException as ex:
# 1.2 doesn't return cf names, so we have to examine the exception
# and ignore if it says the index already exists
if "Index already exists" not in unicode(ex):
raise
def delete_table(model):
cf_name = model.column_family_name()
with connection_manager() as con:
try:
con.execute('drop table {};'.format(cf_name))
except CQLEngineException as ex:
#don't freak out if the table doesn't exist
if 'Cannot drop non existing column family' not in unicode(ex):
raise
try:
execute('drop table {};'.format(cf_name))
except CQLEngineException as ex:
#don't freak out if the table doesn't exist
if 'Cannot drop non existing column family' not in unicode(ex):
raise

View File

@@ -6,7 +6,8 @@ from time import time
from uuid import uuid1
from cqlengine import BaseContainerColumn, BaseValueManager, Map, columns
from cqlengine.connection import connection_manager
from cqlengine.connection import connection_pool, connection_manager, execute
from cqlengine.exceptions import CQLEngineException
from cqlengine.functions import QueryValue, Token
@@ -193,8 +194,7 @@ class BatchQuery(object):
query_list.append('APPLY BATCH;')
with connection_manager() as con:
con.execute('\n'.join(query_list), parameters)
execute('\n'.join(query_list), parameters)
self.queries = []
@@ -346,8 +346,7 @@ class QuerySet(object):
if self._batch:
raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
if self._result_cache is None:
self._con = connection_manager()
self._cur = self._con.execute(self._select_query(), self._where_values())
self._cur = execute(self._select_query(), self._where_values())
self._result_cache = [None]*self._cur.rowcount
if self._cur.description:
names = [i[0] for i in self._cur.description]
@@ -368,7 +367,6 @@ class QuerySet(object):
#return the connection to the connection pool if we have all objects
if self._result_cache and self._result_idx == (len(self._result_cache) - 1):
self._con.close()
self._con = None
self._cur = None
@@ -555,9 +553,8 @@ class QuerySet(object):
qs = ' '.join(qs)
with connection_manager() as con:
cur = con.execute(qs, self._where_values())
return cur.fetchone()[0]
cur = execute(qs, self._where_values())
return cur.fetchone()[0]
else:
return len(self._result_cache)
@@ -635,8 +632,7 @@ class QuerySet(object):
if self._batch:
self._batch.add_query(qs, self._where_values())
else:
with connection_manager() as con:
con.execute(qs, self._where_values())
execute(qs, self._where_values())
def values_list(self, *fields, **kwargs):
""" Instructs the query set to return tuples, not model instance """
@@ -753,8 +749,7 @@ class DMLQuery(object):
if self.batch:
self.batch.add_query(qs, query_values)
else:
with connection_manager() as con:
con.execute(qs, query_values)
execute(qs, query_values)
# delete nulled columns and removed map keys
@@ -787,8 +782,7 @@ class DMLQuery(object):
if self.batch:
self.batch.add_query(qs, query_values)
else:
with connection_manager() as con:
con.execute(qs, query_values)
execute(qs, query_values)
def delete(self):
""" Deletes one instance """
@@ -809,7 +803,6 @@ class DMLQuery(object):
if self.batch:
self.batch.add_query(qs, field_values)
else:
with connection_manager() as con:
con.execute(qs, field_values)
execute(qs, field_values)

View File

@@ -1,13 +1,13 @@
from unittest import TestCase
from cqlengine import connection
from cqlengine import connection
class BaseCassEngTestCase(TestCase):
@classmethod
def setUpClass(cls):
super(BaseCassEngTestCase, cls).setUpClass()
if not connection._connection_pool:
connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')
# todo fix
connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')
def assertHasAttr(self, obj, attr):
self.assertTrue(hasattr(obj, attr),

View File

@@ -1,44 +1,51 @@
from cqlengine.exceptions import CQLEngineException
from cqlengine.management import create_table, delete_table
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.connection import ConnectionPool
from cqlengine.connection import ConnectionPool, Host
from mock import Mock
from mock import Mock, MagicMock, MagicProxy, patch
from cqlengine import management
from cqlengine.tests.query.test_queryset import TestModel
from cql.thrifteries import ThriftConnection
class ConnectionPoolTestCase(BaseCassEngTestCase):
class ConnectionPoolFailoverTestCase(BaseCassEngTestCase):
"""Test cassandra connection pooling."""
def setUp(self):
ConnectionPool.clear()
self.host = Host('127.0.0.1', '9160')
self.pool = ConnectionPool([self.host])
def test_should_create_single_connection_on_request(self):
"""Should create a single connection on first request"""
result = ConnectionPool.get()
self.assertIsNotNone(result)
self.assertEquals(0, ConnectionPool._queue.qsize())
ConnectionPool._queue.put(result)
self.assertEquals(1, ConnectionPool._queue.qsize())
def test_totally_dead_pool(self):
# kill the con
with patch('cqlengine.connection.cql.connect') as mock:
mock.side_effect=CQLEngineException
with self.assertRaises(CQLEngineException):
self.pool.execute("select * from system.peers", {})
def test_should_close_connection_if_queue_is_full(self):
"""Should close additional connections if queue is full"""
connections = [ConnectionPool.get() for x in range(10)]
for conn in connections:
ConnectionPool.put(conn)
fake_conn = Mock()
ConnectionPool.put(fake_conn)
fake_conn.close.assert_called_once_with()
def test_dead_node(self):
self.pool._hosts.append(self.host)
# cursor mock needed so set_cql_version doesn't crap out
ok_cur = MagicMock()
ok_conn = MagicMock()
ok_conn.return_value = ok_cur
returns = [CQLEngineException(), ok_conn]
def side_effect(*args, **kwargs):
result = returns.pop(0)
if isinstance(result, Exception):
raise result
return result
with patch('cqlengine.connection.cql.connect') as mock:
mock.side_effect = side_effect
conn = self.pool._create_connection()
def test_should_pop_connections_from_queue(self):
"""Should pull existing connections off of the queue"""
conn = ConnectionPool.get()
ConnectionPool.put(conn)
self.assertEquals(1, ConnectionPool._queue.qsize())
self.assertEquals(conn, ConnectionPool.get())
self.assertEquals(0, ConnectionPool._queue.qsize())
class CreateKeyspaceTest(BaseCassEngTestCase):
def test_create_succeeeds(self):

View File

@@ -405,18 +405,7 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage):
assert q._con is None
assert q._cur is None
def test_conn_is_returned_after_queryset_is_garbage_collected(self):
""" Tests that the connection is returned to the connection pool after the queryset is gc'd """
from cqlengine.connection import ConnectionPool
# The queue size can be 1 if we just run this file's tests
# It will be 2 when we run 'em all
initial_size = ConnectionPool._queue.qsize()
q = TestModel.objects(test_id=0)
v = q[0]
assert ConnectionPool._queue.qsize() == initial_size - 1
del q
assert ConnectionPool._queue.qsize() == initial_size
class TimeUUIDQueryModel(Model):
partition = columns.UUID(primary_key=True)