added fault tolerant connection manager, which will be easily extended to support connection pooling

added support for multiple connections
removed dependence on setting keyspace name on opening connection
This commit is contained in:
Blake Eggleston
2012-11-24 23:11:09 -08:00
parent 823e21fdee
commit 2db6426d02
7 changed files with 166 additions and 71 deletions

View File

@@ -2,17 +2,67 @@
#http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2 /
#http://cassandra.apache.org/doc/cql/CQL.html
from collections import namedtuple
import random
import cql
from cqlengine.exceptions import CQLEngineException
from thrift.transport.TTransport import TTransportException
class CQLConnectionError(CQLEngineException): pass
_keyspace = 'cassengine_test'
Host = namedtuple('Host', ['name', 'port'])
_hosts = []
_host_idx = 0
_conn= None
_username = None
_password = None
def _set_conn(host):
"""
"""
global _conn
_conn = cql.connect(host.name, host.port, user=_username, password=_password)
_conn.set_cql_version('3.0.0')
def setup(hosts, username=None, password=None):
"""
Records the hosts and connects to one of them
:param hosts: list of hosts, strings in the <hostname>:<port>, or just <hostname>
"""
global _hosts
global _username
global _password
_username = username
_password = password
for host in hosts:
host = host.strip()
host = host.split(':')
if len(host) == 1:
_hosts.append(Host(host[0], 9160))
elif len(host) == 2:
_hosts.append(Host(*host))
else:
raise CQLConnectionError("Can't parse {}".format(''.join(host)))
if not _hosts:
raise CQLConnectionError("At least one host required")
random.shuffle(_hosts)
host = _hosts[_host_idx]
_set_conn(host)
#TODO: look into the cql connection pool class
_conn = {}
def get_connection(keyspace, create_missing_keyspace=True):
con = _conn.get(keyspace)
_old_conn = {}
def get_connection(keyspace=None, create_missing_keyspace=True):
con = _old_conn.get(keyspace)
if con is None:
con = cql.connect('127.0.0.1', 9160)
con.set_cql_version('3.0.0')
@@ -27,7 +77,50 @@ def get_connection(keyspace, create_missing_keyspace=True):
else:
raise CQLConnectionError('"{}" is not an existing keyspace'.format(keyspace))
_conn[keyspace] = con
_old_conn[keyspace] = con
return con
class connection_manager(object):
"""
Connection failure tolerant connection manager. Written to be used in a 'with' block for connection pooling
"""
def __init__(self):
if not _hosts:
raise CQLConnectionError("No connections have been configured, call cqlengine.connection.setup")
self.keyspace = None
self.con = _conn
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
pass
def execute(self, query, params={}):
"""
Gets a connection from the pool and executes the given query, returns the cursor
if there's a connection problem, this will silently create a new connection pool
from the available hosts, and remove the problematic host from the host list
"""
global _host_idx
for i in range(len(_hosts)):
try:
cur = self.con.cursor()
cur.execute(query, params)
return cur
except TTransportException:
#TODO: check for other errors raised in the event of a connection / server problem
#move to the next connection and set the connection pool
self.con_pool.return_connection(self.con)
self.con = None
_host_idx += 1
_host_idx %= len(_hosts)
host = _hosts[_host_idx]
_set_conn(host)
self.con = _conn
raise CQLConnectionError("couldn't reach a Cassandra server")

View File

@@ -1,64 +1,62 @@
from cqlengine.connection import get_connection
from cqlengine.connection import connection_manager
def create_keyspace(name):
con = get_connection(None)
cur = con.cursor()
cur.execute("""CREATE KEYSPACE {}
WITH strategy_class = 'SimpleStrategy'
AND strategy_options:replication_factor=1;""".format(name))
with connection_manager() as con:
con.execute("""CREATE KEYSPACE {}
WITH strategy_class = 'SimpleStrategy'
AND strategy_options:replication_factor=1;""".format(name))
def delete_keyspace(name):
con = get_connection(None)
cur = con.cursor()
cur.execute("DROP KEYSPACE {}".format(name))
with connection_manager() as con:
con.execute("DROP KEYSPACE {}".format(name))
def create_column_family(model):
#construct query string
cf_name = model.column_family_name()
raw_cf_name = model.column_family_name(include_keyspace=False)
conn = get_connection(model.keyspace)
cur = conn.cursor()
with connection_manager() as con:
#check for an existing column family
ks_info = con.con.client.describe_keyspace(model.keyspace)
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
qs = ['CREATE TABLE {}'.format(cf_name)]
#check for an existing column family
ks_info = conn.client.describe_keyspace(model.keyspace)
if not any([cf_name == cf.name for cf in ks_info.cf_defs]):
qs = ['CREATE TABLE {}'.format(cf_name)]
#add column types
pkeys = []
qtypes = []
def add_column(col):
s = '{} {}'.format(col.db_field_name, col.db_type)
if col.primary_key: pkeys.append(col.db_field_name)
qtypes.append(s)
for name, col in model._columns.items():
add_column(col)
#add column types
pkeys = []
qtypes = []
def add_column(col):
s = '{} {}'.format(col.db_field_name, col.db_type)
if col.primary_key: pkeys.append(col.db_field_name)
qtypes.append(s)
for name, col in model._columns.items():
add_column(col)
qtypes.append('PRIMARY KEY ({})'.format(', '.join(pkeys)))
qtypes.append('PRIMARY KEY ({})'.format(', '.join(pkeys)))
qs += ['({})'.format(', '.join(qtypes))]
qs = ' '.join(qs)
cur.execute(qs)
indexes = [c for n,c in model._columns.items() if c.index]
if indexes:
for column in indexes:
#TODO: check for existing index...
#can that be determined from the connection client?
qs = ['CREATE INDEX {}'.format(column.db_index_name)]
qs += ['ON {}'.format(cf_name)]
qs += ['({})'.format(column.db_field_name)]
qs += ['({})'.format(', '.join(qtypes))]
qs = ' '.join(qs)
cur.execute(qs)
con.execute(qs)
indexes = [c for n,c in model._columns.items() if c.index]
if indexes:
for column in indexes:
#TODO: check for existing index...
#can that be determined from the connection client?
qs = ['CREATE INDEX {}'.format(column.db_index_name)]
qs += ['ON {}'.format(cf_name)]
qs += ['({})'.format(column.db_field_name)]
qs = ' '.join(qs)
con.execute(qs)
def delete_column_family(model):
#check that model exists
cf_name = model.column_family_name()
conn = get_connection(model.keyspace)
ks_info = conn.client.describe_keyspace(model.keyspace)
if any([cf_name == cf.name for cf in ks_info.cf_defs]):
cur = conn.cursor()
cur.execute('drop table {};'.format(cf_name))
raw_cf_name = model.column_family_name(include_keyspace=False)
with connection_manager() as con:
ks_info = con.con.client.describe_keyspace(model.keyspace)
if any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
con.execute('drop table {};'.format(cf_name))

View File

@@ -24,10 +24,8 @@ class BaseModel(object):
value_mngr = column.value_manager(self, column, values.get(name, None))
self._values[name] = value_mngr
#TODO: note any absent fields so they're not deleted
@classmethod
def column_family_name(cls):
def column_family_name(cls, include_keyspace=True):
"""
Returns the column family name if it's been defined
otherwise, it creates it from the module and class name
@@ -38,7 +36,9 @@ class BaseModel(object):
cf_name = cf_name.replace('.', '_')
#trim to less than 48 characters or cassandra will complain
cf_name = cf_name[-48:]
return cf_name.lower()
cf_name = cf_name.lower()
if not include_keyspace: return cf_name
return '{}.{}'.format(cls.keyspace, cf_name)
@property
def pk(self):

View File

@@ -4,6 +4,7 @@ from hashlib import md5
from time import time
from cqlengine.connection import get_connection
from cqlengine.connection import connection_manager
from cqlengine.exceptions import CQLEngineException
#CQL 3 reference:
@@ -220,7 +221,7 @@ class QuerySet(object):
def __iter__(self):
#TODO: cache results
if self._cursor is None:
conn = get_connection(self.model.keyspace)
conn = get_connection()
self._cursor = conn.cursor()
self._cursor.execute(self._select_query(), self._where_values())
self._rowcount = self._cursor.rowcount
@@ -350,10 +351,9 @@ class QuerySet(object):
qs += ['WHERE {}'.format(self._where_clause())]
qs = ' '.join(qs)
con = get_connection(self.model.keyspace)
cur = con.cursor()
cur.execute(qs, self._where_values())
return cur.fetchone()[0]
with connection_manager() as con:
cur = con.execute(qs, self._where_values())
return cur.fetchone()[0]
def limit(self, v):
"""
@@ -430,11 +430,10 @@ class QuerySet(object):
qs += ["({})".format(', '.join([':'+f for f in field_names]))]
qs = ' '.join(qs)
conn = get_connection(self.model.keyspace)
cur = conn.cursor()
cur.execute(qs, field_values)
with connection_manager() as con:
con.execute(qs, field_values)
#TODO: delete deleted / nulled columns
#delete deleted / nulled columns
deleted = [k for k,v in instance._values.items() if v.deleted]
if deleted:
del_fields = [self.model._columns[f] for f in deleted]
@@ -448,11 +447,11 @@ class QuerySet(object):
qs = ' '.join(qs)
pk_dict = dict([(v.db_field_name, getattr(instance, k)) for k,v in pks.items()])
cur.execute(qs, pk_dict)
with connection_manager() as con:
con.execute(qs, pk_dict)
def create(self, **kwargs):
return self.model(**kwargs).save()
@@ -466,9 +465,8 @@ class QuerySet(object):
qs += ['WHERE {}'.format(self._where_clause())]
qs = ' '.join(qs)
con = get_connection(self.model.keyspace)
cur = con.cursor()
cur.execute(qs, self._where_values())
with connection_manager() as con:
con.execute(qs, self._where_values())
def delete_instance(self, instance):
@@ -478,8 +476,7 @@ class QuerySet(object):
qs += ['WHERE {0}=:{0}'.format(pk_name)]
qs = ' '.join(qs)
conn = get_connection(self.model.keyspace)
cur = conn.cursor()
cur.execute(qs, {pk_name:instance.pk})
with connection_manager() as con:
con.execute(qs, {pk_name:instance.pk})

View File

@@ -1,7 +1,14 @@
from unittest import TestCase
from cqlengine import connection
class BaseCassEngTestCase(TestCase):
@classmethod
def setUpClass(cls):
super(BaseCassEngTestCase, cls).setUpClass()
if not connection._hosts:
connection.setup(['localhost'])
def assertHasAttr(self, obj, attr):
self.assertTrue(hasattr(obj, attr),
"{} doesn't have attribute: {}".format(obj, attr))

View File