Getting connection pooling working with mocking library
Verified throwing exception when no servers are available, but correctly recovering and hitting the next server when one is. fixing minor pooling tests ensure we get the right exception back when no servers are available working on tests for retry connections working despite failure Removed old connection_manager and replaced with a simple context manager that allows for easy access to clients within the main pool
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -39,3 +39,6 @@ html/
|
|||||||
#Mr Developer
|
#Mr Developer
|
||||||
.mr.developer.cfg
|
.mr.developer.cfg
|
||||||
.noseids
|
.noseids
|
||||||
|
/commitlog
|
||||||
|
/data
|
||||||
|
|
||||||
|
@@ -9,8 +9,11 @@ import random
|
|||||||
import cql
|
import cql
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from copy import copy
|
||||||
from cqlengine.exceptions import CQLEngineException
|
from cqlengine.exceptions import CQLEngineException
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
from thrift.transport.TTransport import TTransportException
|
from thrift.transport.TTransport import TTransportException
|
||||||
|
|
||||||
LOG = logging.getLogger('cqlengine.cql')
|
LOG = logging.getLogger('cqlengine.cql')
|
||||||
@@ -20,7 +23,9 @@ class CQLConnectionError(CQLEngineException): pass
|
|||||||
Host = namedtuple('Host', ['name', 'port'])
|
Host = namedtuple('Host', ['name', 'port'])
|
||||||
|
|
||||||
_max_connections = 10
|
_max_connections = 10
|
||||||
_connection_pool = None
|
|
||||||
|
# global connection pool
|
||||||
|
connection_pool = None
|
||||||
|
|
||||||
def setup(hosts, username=None, password=None, max_connections=10, default_keyspace=None):
|
def setup(hosts, username=None, password=None, max_connections=10, default_keyspace=None):
|
||||||
"""
|
"""
|
||||||
@@ -29,7 +34,7 @@ def setup(hosts, username=None, password=None, max_connections=10, default_keysp
|
|||||||
:param hosts: list of hosts, strings in the <hostname>:<port>, or just <hostname>
|
:param hosts: list of hosts, strings in the <hostname>:<port>, or just <hostname>
|
||||||
"""
|
"""
|
||||||
global _max_connections
|
global _max_connections
|
||||||
global _connection_pool
|
global connection_pool
|
||||||
_max_connections = max_connections
|
_max_connections = max_connections
|
||||||
|
|
||||||
if default_keyspace:
|
if default_keyspace:
|
||||||
@@ -50,16 +55,13 @@ def setup(hosts, username=None, password=None, max_connections=10, default_keysp
|
|||||||
if not _hosts:
|
if not _hosts:
|
||||||
raise CQLConnectionError("At least one host required")
|
raise CQLConnectionError("At least one host required")
|
||||||
|
|
||||||
_connection_pool = ConnectionPool(_hosts)
|
connection_pool = ConnectionPool(_hosts, username, password)
|
||||||
|
|
||||||
|
|
||||||
class ConnectionPool(object):
|
class ConnectionPool(object):
|
||||||
"""Handles pooling of database connections."""
|
"""Handles pooling of database connections."""
|
||||||
|
|
||||||
# Connection pool queue
|
def __init__(self, hosts, username=None, password=None):
|
||||||
_queue = None
|
|
||||||
|
|
||||||
def __init__(self, hosts, username, password):
|
|
||||||
self._hosts = hosts
|
self._hosts = hosts
|
||||||
self._username = username
|
self._username = username
|
||||||
self._password = password
|
self._password = password
|
||||||
@@ -113,58 +115,40 @@ class ConnectionPool(object):
|
|||||||
if not self._hosts:
|
if not self._hosts:
|
||||||
raise CQLConnectionError("At least one host required")
|
raise CQLConnectionError("At least one host required")
|
||||||
|
|
||||||
host = _hosts[_host_idx]
|
hosts = copy(self._hosts)
|
||||||
|
random.shuffle(hosts)
|
||||||
|
|
||||||
new_conn = cql.connect(host.name, host.port, user=_username, password=_password)
|
for host in hosts:
|
||||||
new_conn.set_cql_version('3.0.0')
|
|
||||||
return new_conn
|
|
||||||
|
|
||||||
|
|
||||||
class connection_manager(object):
|
|
||||||
"""
|
|
||||||
Connection failure tolerant connection manager. Written to be used in a 'with' block for connection pooling
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
|
||||||
if not _hosts:
|
|
||||||
raise CQLConnectionError("No connections have been configured, call cqlengine.connection.setup")
|
|
||||||
self.keyspace = None
|
|
||||||
self.con = ConnectionPool.get()
|
|
||||||
self.cur = None
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
if self.cur: self.cur.close()
|
|
||||||
ConnectionPool.put(self.con)
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def execute(self, query, params={}):
|
|
||||||
"""
|
|
||||||
Gets a connection from the pool and executes the given query, returns the cursor
|
|
||||||
|
|
||||||
if there's a connection problem, this will silently create a new connection pool
|
|
||||||
from the available hosts, and remove the problematic host from the host list
|
|
||||||
"""
|
|
||||||
global _host_idx
|
|
||||||
|
|
||||||
for i in range(len(_hosts)):
|
|
||||||
try:
|
try:
|
||||||
LOG.debug('{} {}'.format(query, repr(params)))
|
new_conn = cql.connect(host.name, host.port, user=self._username, password=self._password)
|
||||||
self.cur = self.con.cursor()
|
new_conn.set_cql_version('3.0.0')
|
||||||
self.cur.execute(query, params)
|
return new_conn
|
||||||
return self.cur
|
except Exception as e:
|
||||||
except cql.ProgrammingError as ex:
|
logging.debug("Could not establish connection to {}:{}".format(host.name, host.port))
|
||||||
raise CQLEngineException(unicode(ex))
|
pass
|
||||||
except TTransportException:
|
|
||||||
#TODO: check for other errors raised in the event of a connection / server problem
|
|
||||||
#move to the next connection and set the connection pool
|
|
||||||
_host_idx += 1
|
|
||||||
_host_idx %= len(_hosts)
|
|
||||||
self.con.close()
|
|
||||||
self.con = ConnectionPool._create_connection()
|
|
||||||
|
|
||||||
raise CQLConnectionError("couldn't reach a Cassandra server")
|
raise CQLConnectionError("Could not connect to any server in cluster")
|
||||||
|
|
||||||
|
def execute(self, query, params):
|
||||||
|
try:
|
||||||
|
con = self.get()
|
||||||
|
cur = con.cursor()
|
||||||
|
cur.execute(query, params)
|
||||||
|
self.put(con)
|
||||||
|
return cur
|
||||||
|
except cql.ProgrammingError as ex:
|
||||||
|
raise CQLEngineException(unicode(ex))
|
||||||
|
except TTransportException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
raise CQLEngineException("Could not execute query against the cluster")
|
||||||
|
|
||||||
|
def execute(query, params={}):
|
||||||
|
return connection_pool.execute(query, params)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def connection_manager():
|
||||||
|
global connection_pool
|
||||||
|
tmp = connection_pool.get()
|
||||||
|
yield tmp
|
||||||
|
connection_pool.put(tmp)
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from cqlengine.connection import connection_manager
|
from cqlengine.connection import connection_manager, execute
|
||||||
from cqlengine.exceptions import CQLEngineException
|
from cqlengine.exceptions import CQLEngineException
|
||||||
|
|
||||||
def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3, durable_writes=True, **replication_values):
|
def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3, durable_writes=True, **replication_values):
|
||||||
@@ -15,11 +15,11 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3,
|
|||||||
"""
|
"""
|
||||||
with connection_manager() as con:
|
with connection_manager() as con:
|
||||||
#TODO: check system tables instead of using cql thrifteries
|
#TODO: check system tables instead of using cql thrifteries
|
||||||
if not any([name == k.name for k in con.con.client.describe_keyspaces()]):
|
if not any([name == k.name for k in con.client.describe_keyspaces()]):
|
||||||
# if name not in [k.name for k in con.con.client.describe_keyspaces()]:
|
# if name not in [k.name for k in con.con.client.describe_keyspaces()]:
|
||||||
try:
|
try:
|
||||||
#Try the 1.1 method
|
#Try the 1.1 method
|
||||||
con.execute("""CREATE KEYSPACE {}
|
execute("""CREATE KEYSPACE {}
|
||||||
WITH strategy_class = '{}'
|
WITH strategy_class = '{}'
|
||||||
AND strategy_options:replication_factor={};""".format(name, strategy_class, replication_factor))
|
AND strategy_options:replication_factor={};""".format(name, strategy_class, replication_factor))
|
||||||
except CQLEngineException:
|
except CQLEngineException:
|
||||||
@@ -38,12 +38,12 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3,
|
|||||||
if strategy_class != 'SimpleStrategy':
|
if strategy_class != 'SimpleStrategy':
|
||||||
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
|
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
|
||||||
|
|
||||||
con.execute(query)
|
execute(query)
|
||||||
|
|
||||||
def delete_keyspace(name):
|
def delete_keyspace(name):
|
||||||
with connection_manager() as con:
|
with connection_manager() as con:
|
||||||
if name in [k.name for k in con.con.client.describe_keyspaces()]:
|
if name in [k.name for k in con.client.describe_keyspaces()]:
|
||||||
con.execute("DROP KEYSPACE {}".format(name))
|
execute("DROP KEYSPACE {}".format(name))
|
||||||
|
|
||||||
def create_table(model, create_missing_keyspace=True):
|
def create_table(model, create_missing_keyspace=True):
|
||||||
#construct query string
|
#construct query string
|
||||||
@@ -55,78 +55,81 @@ def create_table(model, create_missing_keyspace=True):
|
|||||||
create_keyspace(model._get_keyspace())
|
create_keyspace(model._get_keyspace())
|
||||||
|
|
||||||
with connection_manager() as con:
|
with connection_manager() as con:
|
||||||
#check for an existing column family
|
ks_info = con.client.describe_keyspace(model._get_keyspace())
|
||||||
#TODO: check system tables instead of using cql thrifteries
|
|
||||||
ks_info = con.con.client.describe_keyspace(model._get_keyspace())
|
|
||||||
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
|
|
||||||
qs = ['CREATE TABLE {}'.format(cf_name)]
|
|
||||||
|
|
||||||
#add column types
|
#check for an existing column family
|
||||||
pkeys = []
|
#TODO: check system tables instead of using cql thrifteries
|
||||||
ckeys = []
|
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
|
||||||
qtypes = []
|
qs = ['CREATE TABLE {}'.format(cf_name)]
|
||||||
def add_column(col):
|
|
||||||
s = col.get_column_def()
|
|
||||||
if col.primary_key:
|
|
||||||
keys = (pkeys if col.partition_key else ckeys)
|
|
||||||
keys.append('"{}"'.format(col.db_field_name))
|
|
||||||
qtypes.append(s)
|
|
||||||
for name, col in model._columns.items():
|
|
||||||
add_column(col)
|
|
||||||
|
|
||||||
qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or ''))
|
#add column types
|
||||||
|
pkeys = []
|
||||||
|
ckeys = []
|
||||||
|
qtypes = []
|
||||||
|
def add_column(col):
|
||||||
|
s = col.get_column_def()
|
||||||
|
if col.primary_key:
|
||||||
|
keys = (pkeys if col.partition_key else ckeys)
|
||||||
|
keys.append('"{}"'.format(col.db_field_name))
|
||||||
|
qtypes.append(s)
|
||||||
|
for name, col in model._columns.items():
|
||||||
|
add_column(col)
|
||||||
|
|
||||||
qs += ['({})'.format(', '.join(qtypes))]
|
qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or ''))
|
||||||
|
|
||||||
with_qs = ['read_repair_chance = {}'.format(model.read_repair_chance)]
|
qs += ['({})'.format(', '.join(qtypes))]
|
||||||
|
|
||||||
_order = ["%s %s" % (c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()]
|
with_qs = ['read_repair_chance = {}'.format(model.read_repair_chance)]
|
||||||
if _order:
|
|
||||||
with_qs.append("clustering order by ({})".format(', '.join(_order)))
|
|
||||||
|
|
||||||
# add read_repair_chance
|
_order = ["%s %s" % (c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()]
|
||||||
qs += ['WITH {}'.format(' AND '.join(with_qs))]
|
if _order:
|
||||||
|
with_qs.append("clustering order by ({})".format(', '.join(_order)))
|
||||||
|
|
||||||
|
# add read_repair_chance
|
||||||
|
qs += ['WITH {}'.format(' AND '.join(with_qs))]
|
||||||
|
qs = ' '.join(qs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
execute(qs)
|
||||||
|
except CQLEngineException as ex:
|
||||||
|
# 1.2 doesn't return cf names, so we have to examine the exception
|
||||||
|
# and ignore if it says the column family already exists
|
||||||
|
if "Cannot add already existing column family" not in unicode(ex):
|
||||||
|
raise
|
||||||
|
|
||||||
|
#get existing index names, skip ones that already exist
|
||||||
|
with connection_manager() as con:
|
||||||
|
ks_info = con.client.describe_keyspace(model._get_keyspace())
|
||||||
|
|
||||||
|
cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name]
|
||||||
|
idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else []
|
||||||
|
idx_names = filter(None, idx_names)
|
||||||
|
|
||||||
|
indexes = [c for n,c in model._columns.items() if c.index]
|
||||||
|
if indexes:
|
||||||
|
for column in indexes:
|
||||||
|
if column.db_index_name in idx_names: continue
|
||||||
|
qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)]
|
||||||
|
qs += ['ON {}'.format(cf_name)]
|
||||||
|
qs += ['("{}")'.format(column.db_field_name)]
|
||||||
qs = ' '.join(qs)
|
qs = ' '.join(qs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
con.execute(qs)
|
execute(qs)
|
||||||
except CQLEngineException as ex:
|
except CQLEngineException as ex:
|
||||||
# 1.2 doesn't return cf names, so we have to examine the exception
|
# 1.2 doesn't return cf names, so we have to examine the exception
|
||||||
# and ignore if it says the column family already exists
|
# and ignore if it says the index already exists
|
||||||
if "Cannot add already existing column family" not in unicode(ex):
|
if "Index already exists" not in unicode(ex):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
#get existing index names, skip ones that already exist
|
|
||||||
ks_info = con.con.client.describe_keyspace(model._get_keyspace())
|
|
||||||
cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name]
|
|
||||||
idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else []
|
|
||||||
idx_names = filter(None, idx_names)
|
|
||||||
|
|
||||||
indexes = [c for n,c in model._columns.items() if c.index]
|
|
||||||
if indexes:
|
|
||||||
for column in indexes:
|
|
||||||
if column.db_index_name in idx_names: continue
|
|
||||||
qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)]
|
|
||||||
qs += ['ON {}'.format(cf_name)]
|
|
||||||
qs += ['("{}")'.format(column.db_field_name)]
|
|
||||||
qs = ' '.join(qs)
|
|
||||||
|
|
||||||
try:
|
|
||||||
con.execute(qs)
|
|
||||||
except CQLEngineException as ex:
|
|
||||||
# 1.2 doesn't return cf names, so we have to examine the exception
|
|
||||||
# and ignore if it says the index already exists
|
|
||||||
if "Index already exists" not in unicode(ex):
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def delete_table(model):
|
def delete_table(model):
|
||||||
cf_name = model.column_family_name()
|
cf_name = model.column_family_name()
|
||||||
with connection_manager() as con:
|
|
||||||
try:
|
try:
|
||||||
con.execute('drop table {};'.format(cf_name))
|
execute('drop table {};'.format(cf_name))
|
||||||
except CQLEngineException as ex:
|
except CQLEngineException as ex:
|
||||||
#don't freak out if the table doesn't exist
|
#don't freak out if the table doesn't exist
|
||||||
if 'Cannot drop non existing column family' not in unicode(ex):
|
if 'Cannot drop non existing column family' not in unicode(ex):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@@ -6,7 +6,8 @@ from time import time
|
|||||||
from uuid import uuid1
|
from uuid import uuid1
|
||||||
from cqlengine import BaseContainerColumn, BaseValueManager, Map, columns
|
from cqlengine import BaseContainerColumn, BaseValueManager, Map, columns
|
||||||
|
|
||||||
from cqlengine.connection import connection_manager
|
from cqlengine.connection import connection_pool, connection_manager, execute
|
||||||
|
|
||||||
from cqlengine.exceptions import CQLEngineException
|
from cqlengine.exceptions import CQLEngineException
|
||||||
from cqlengine.functions import QueryValue, Token
|
from cqlengine.functions import QueryValue, Token
|
||||||
|
|
||||||
@@ -193,8 +194,7 @@ class BatchQuery(object):
|
|||||||
|
|
||||||
query_list.append('APPLY BATCH;')
|
query_list.append('APPLY BATCH;')
|
||||||
|
|
||||||
with connection_manager() as con:
|
execute('\n'.join(query_list), parameters)
|
||||||
con.execute('\n'.join(query_list), parameters)
|
|
||||||
|
|
||||||
self.queries = []
|
self.queries = []
|
||||||
|
|
||||||
@@ -346,8 +346,7 @@ class QuerySet(object):
|
|||||||
if self._batch:
|
if self._batch:
|
||||||
raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
|
raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
|
||||||
if self._result_cache is None:
|
if self._result_cache is None:
|
||||||
self._con = connection_manager()
|
self._cur = execute(self._select_query(), self._where_values())
|
||||||
self._cur = self._con.execute(self._select_query(), self._where_values())
|
|
||||||
self._result_cache = [None]*self._cur.rowcount
|
self._result_cache = [None]*self._cur.rowcount
|
||||||
if self._cur.description:
|
if self._cur.description:
|
||||||
names = [i[0] for i in self._cur.description]
|
names = [i[0] for i in self._cur.description]
|
||||||
@@ -368,7 +367,6 @@ class QuerySet(object):
|
|||||||
|
|
||||||
#return the connection to the connection pool if we have all objects
|
#return the connection to the connection pool if we have all objects
|
||||||
if self._result_cache and self._result_idx == (len(self._result_cache) - 1):
|
if self._result_cache and self._result_idx == (len(self._result_cache) - 1):
|
||||||
self._con.close()
|
|
||||||
self._con = None
|
self._con = None
|
||||||
self._cur = None
|
self._cur = None
|
||||||
|
|
||||||
@@ -555,9 +553,8 @@ class QuerySet(object):
|
|||||||
|
|
||||||
qs = ' '.join(qs)
|
qs = ' '.join(qs)
|
||||||
|
|
||||||
with connection_manager() as con:
|
cur = execute(qs, self._where_values())
|
||||||
cur = con.execute(qs, self._where_values())
|
return cur.fetchone()[0]
|
||||||
return cur.fetchone()[0]
|
|
||||||
else:
|
else:
|
||||||
return len(self._result_cache)
|
return len(self._result_cache)
|
||||||
|
|
||||||
@@ -635,8 +632,7 @@ class QuerySet(object):
|
|||||||
if self._batch:
|
if self._batch:
|
||||||
self._batch.add_query(qs, self._where_values())
|
self._batch.add_query(qs, self._where_values())
|
||||||
else:
|
else:
|
||||||
with connection_manager() as con:
|
execute(qs, self._where_values())
|
||||||
con.execute(qs, self._where_values())
|
|
||||||
|
|
||||||
def values_list(self, *fields, **kwargs):
|
def values_list(self, *fields, **kwargs):
|
||||||
""" Instructs the query set to return tuples, not model instance """
|
""" Instructs the query set to return tuples, not model instance """
|
||||||
@@ -753,8 +749,7 @@ class DMLQuery(object):
|
|||||||
if self.batch:
|
if self.batch:
|
||||||
self.batch.add_query(qs, query_values)
|
self.batch.add_query(qs, query_values)
|
||||||
else:
|
else:
|
||||||
with connection_manager() as con:
|
execute(qs, query_values)
|
||||||
con.execute(qs, query_values)
|
|
||||||
|
|
||||||
|
|
||||||
# delete nulled columns and removed map keys
|
# delete nulled columns and removed map keys
|
||||||
@@ -787,8 +782,7 @@ class DMLQuery(object):
|
|||||||
if self.batch:
|
if self.batch:
|
||||||
self.batch.add_query(qs, query_values)
|
self.batch.add_query(qs, query_values)
|
||||||
else:
|
else:
|
||||||
with connection_manager() as con:
|
execute(qs, query_values)
|
||||||
con.execute(qs, query_values)
|
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
""" Deletes one instance """
|
""" Deletes one instance """
|
||||||
@@ -809,7 +803,6 @@ class DMLQuery(object):
|
|||||||
if self.batch:
|
if self.batch:
|
||||||
self.batch.add_query(qs, field_values)
|
self.batch.add_query(qs, field_values)
|
||||||
else:
|
else:
|
||||||
with connection_manager() as con:
|
execute(qs, field_values)
|
||||||
con.execute(qs, field_values)
|
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,13 +1,13 @@
|
|||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
from cqlengine import connection
|
from cqlengine import connection
|
||||||
|
|
||||||
class BaseCassEngTestCase(TestCase):
|
class BaseCassEngTestCase(TestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
super(BaseCassEngTestCase, cls).setUpClass()
|
super(BaseCassEngTestCase, cls).setUpClass()
|
||||||
if not connection._connection_pool:
|
# todo fix
|
||||||
connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')
|
connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')
|
||||||
|
|
||||||
def assertHasAttr(self, obj, attr):
|
def assertHasAttr(self, obj, attr):
|
||||||
self.assertTrue(hasattr(obj, attr),
|
self.assertTrue(hasattr(obj, attr),
|
||||||
|
@@ -1,43 +1,50 @@
|
|||||||
|
from cqlengine.exceptions import CQLEngineException
|
||||||
from cqlengine.management import create_table, delete_table
|
from cqlengine.management import create_table, delete_table
|
||||||
from cqlengine.tests.base import BaseCassEngTestCase
|
from cqlengine.tests.base import BaseCassEngTestCase
|
||||||
|
|
||||||
from cqlengine.connection import ConnectionPool
|
from cqlengine.connection import ConnectionPool, Host
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock, MagicMock, MagicProxy, patch
|
||||||
from cqlengine import management
|
from cqlengine import management
|
||||||
from cqlengine.tests.query.test_queryset import TestModel
|
from cqlengine.tests.query.test_queryset import TestModel
|
||||||
|
|
||||||
|
from cql.thrifteries import ThriftConnection
|
||||||
|
|
||||||
class ConnectionPoolTestCase(BaseCassEngTestCase):
|
class ConnectionPoolFailoverTestCase(BaseCassEngTestCase):
|
||||||
"""Test cassandra connection pooling."""
|
"""Test cassandra connection pooling."""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
ConnectionPool.clear()
|
self.host = Host('127.0.0.1', '9160')
|
||||||
|
self.pool = ConnectionPool([self.host])
|
||||||
|
|
||||||
def test_should_create_single_connection_on_request(self):
|
def test_totally_dead_pool(self):
|
||||||
"""Should create a single connection on first request"""
|
# kill the con
|
||||||
result = ConnectionPool.get()
|
with patch('cqlengine.connection.cql.connect') as mock:
|
||||||
self.assertIsNotNone(result)
|
mock.side_effect=CQLEngineException
|
||||||
self.assertEquals(0, ConnectionPool._queue.qsize())
|
with self.assertRaises(CQLEngineException):
|
||||||
ConnectionPool._queue.put(result)
|
self.pool.execute("select * from system.peers", {})
|
||||||
self.assertEquals(1, ConnectionPool._queue.qsize())
|
|
||||||
|
|
||||||
def test_should_close_connection_if_queue_is_full(self):
|
def test_dead_node(self):
|
||||||
"""Should close additional connections if queue is full"""
|
self.pool._hosts.append(self.host)
|
||||||
connections = [ConnectionPool.get() for x in range(10)]
|
|
||||||
for conn in connections:
|
|
||||||
ConnectionPool.put(conn)
|
|
||||||
fake_conn = Mock()
|
|
||||||
ConnectionPool.put(fake_conn)
|
|
||||||
fake_conn.close.assert_called_once_with()
|
|
||||||
|
|
||||||
def test_should_pop_connections_from_queue(self):
|
# cursor mock needed so set_cql_version doesn't crap out
|
||||||
"""Should pull existing connections off of the queue"""
|
ok_cur = MagicMock()
|
||||||
conn = ConnectionPool.get()
|
|
||||||
ConnectionPool.put(conn)
|
ok_conn = MagicMock()
|
||||||
self.assertEquals(1, ConnectionPool._queue.qsize())
|
ok_conn.return_value = ok_cur
|
||||||
self.assertEquals(conn, ConnectionPool.get())
|
|
||||||
self.assertEquals(0, ConnectionPool._queue.qsize())
|
|
||||||
|
returns = [CQLEngineException(), ok_conn]
|
||||||
|
|
||||||
|
def side_effect(*args, **kwargs):
|
||||||
|
result = returns.pop(0)
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
raise result
|
||||||
|
return result
|
||||||
|
|
||||||
|
with patch('cqlengine.connection.cql.connect') as mock:
|
||||||
|
mock.side_effect = side_effect
|
||||||
|
conn = self.pool._create_connection()
|
||||||
|
|
||||||
|
|
||||||
class CreateKeyspaceTest(BaseCassEngTestCase):
|
class CreateKeyspaceTest(BaseCassEngTestCase):
|
||||||
|
@@ -405,18 +405,7 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage):
|
|||||||
assert q._con is None
|
assert q._con is None
|
||||||
assert q._cur is None
|
assert q._cur is None
|
||||||
|
|
||||||
def test_conn_is_returned_after_queryset_is_garbage_collected(self):
|
|
||||||
""" Tests that the connection is returned to the connection pool after the queryset is gc'd """
|
|
||||||
from cqlengine.connection import ConnectionPool
|
|
||||||
# The queue size can be 1 if we just run this file's tests
|
|
||||||
# It will be 2 when we run 'em all
|
|
||||||
initial_size = ConnectionPool._queue.qsize()
|
|
||||||
q = TestModel.objects(test_id=0)
|
|
||||||
v = q[0]
|
|
||||||
assert ConnectionPool._queue.qsize() == initial_size - 1
|
|
||||||
|
|
||||||
del q
|
|
||||||
assert ConnectionPool._queue.qsize() == initial_size
|
|
||||||
|
|
||||||
class TimeUUIDQueryModel(Model):
|
class TimeUUIDQueryModel(Model):
|
||||||
partition = columns.UUID(primary_key=True)
|
partition = columns.UUID(primary_key=True)
|
||||||
|
Reference in New Issue
Block a user