added fault tolerant connection manager, which will be easily extended to support connection pooling
added support for multiple connections removed dependence on setting keyspace name on opening connection
This commit is contained in:
@@ -2,17 +2,67 @@
|
|||||||
#http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2 /
|
#http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2 /
|
||||||
#http://cassandra.apache.org/doc/cql/CQL.html
|
#http://cassandra.apache.org/doc/cql/CQL.html
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
import random
|
||||||
|
|
||||||
import cql
|
import cql
|
||||||
|
|
||||||
from cqlengine.exceptions import CQLEngineException
|
from cqlengine.exceptions import CQLEngineException
|
||||||
|
|
||||||
|
from thrift.transport.TTransport import TTransportException
|
||||||
|
|
||||||
|
|
||||||
class CQLConnectionError(CQLEngineException): pass
|
class CQLConnectionError(CQLEngineException): pass
|
||||||
|
|
||||||
_keyspace = 'cassengine_test'
|
Host = namedtuple('Host', ['name', 'port'])
|
||||||
|
_hosts = []
|
||||||
|
_host_idx = 0
|
||||||
|
_conn= None
|
||||||
|
_username = None
|
||||||
|
_password = None
|
||||||
|
|
||||||
|
def _set_conn(host):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
global _conn
|
||||||
|
_conn = cql.connect(host.name, host.port, user=_username, password=_password)
|
||||||
|
_conn.set_cql_version('3.0.0')
|
||||||
|
|
||||||
|
def setup(hosts, username=None, password=None):
|
||||||
|
"""
|
||||||
|
Records the hosts and connects to one of them
|
||||||
|
|
||||||
|
:param hosts: list of hosts, strings in the <hostname>:<port>, or just <hostname>
|
||||||
|
"""
|
||||||
|
global _hosts
|
||||||
|
global _username
|
||||||
|
global _password
|
||||||
|
|
||||||
|
_username = username
|
||||||
|
_password = password
|
||||||
|
|
||||||
|
for host in hosts:
|
||||||
|
host = host.strip()
|
||||||
|
host = host.split(':')
|
||||||
|
if len(host) == 1:
|
||||||
|
_hosts.append(Host(host[0], 9160))
|
||||||
|
elif len(host) == 2:
|
||||||
|
_hosts.append(Host(*host))
|
||||||
|
else:
|
||||||
|
raise CQLConnectionError("Can't parse {}".format(''.join(host)))
|
||||||
|
|
||||||
|
if not _hosts:
|
||||||
|
raise CQLConnectionError("At least one host required")
|
||||||
|
|
||||||
|
random.shuffle(_hosts)
|
||||||
|
host = _hosts[_host_idx]
|
||||||
|
_set_conn(host)
|
||||||
|
|
||||||
|
|
||||||
#TODO: look into the cql connection pool class
|
#TODO: look into the cql connection pool class
|
||||||
_conn = {}
|
_old_conn = {}
|
||||||
def get_connection(keyspace, create_missing_keyspace=True):
|
def get_connection(keyspace=None, create_missing_keyspace=True):
|
||||||
con = _conn.get(keyspace)
|
con = _old_conn.get(keyspace)
|
||||||
if con is None:
|
if con is None:
|
||||||
con = cql.connect('127.0.0.1', 9160)
|
con = cql.connect('127.0.0.1', 9160)
|
||||||
con.set_cql_version('3.0.0')
|
con.set_cql_version('3.0.0')
|
||||||
@@ -27,7 +77,50 @@ def get_connection(keyspace, create_missing_keyspace=True):
|
|||||||
else:
|
else:
|
||||||
raise CQLConnectionError('"{}" is not an existing keyspace'.format(keyspace))
|
raise CQLConnectionError('"{}" is not an existing keyspace'.format(keyspace))
|
||||||
|
|
||||||
_conn[keyspace] = con
|
_old_conn[keyspace] = con
|
||||||
|
|
||||||
return con
|
return con
|
||||||
|
|
||||||
|
class connection_manager(object):
|
||||||
|
"""
|
||||||
|
Connection failure tolerant connection manager. Written to be used in a 'with' block for connection pooling
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
if not _hosts:
|
||||||
|
raise CQLConnectionError("No connections have been configured, call cqlengine.connection.setup")
|
||||||
|
self.keyspace = None
|
||||||
|
self.con = _conn
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def execute(self, query, params={}):
|
||||||
|
"""
|
||||||
|
Gets a connection from the pool and executes the given query, returns the cursor
|
||||||
|
|
||||||
|
if there's a connection problem, this will silently create a new connection pool
|
||||||
|
from the available hosts, and remove the problematic host from the host list
|
||||||
|
"""
|
||||||
|
global _host_idx
|
||||||
|
|
||||||
|
for i in range(len(_hosts)):
|
||||||
|
try:
|
||||||
|
cur = self.con.cursor()
|
||||||
|
cur.execute(query, params)
|
||||||
|
return cur
|
||||||
|
except TTransportException:
|
||||||
|
#TODO: check for other errors raised in the event of a connection / server problem
|
||||||
|
#move to the next connection and set the connection pool
|
||||||
|
self.con_pool.return_connection(self.con)
|
||||||
|
self.con = None
|
||||||
|
_host_idx += 1
|
||||||
|
_host_idx %= len(_hosts)
|
||||||
|
host = _hosts[_host_idx]
|
||||||
|
_set_conn(host)
|
||||||
|
self.con = _conn
|
||||||
|
|
||||||
|
raise CQLConnectionError("couldn't reach a Cassandra server")
|
||||||
|
|
||||||
|
|||||||
@@ -1,27 +1,25 @@
|
|||||||
from cqlengine.connection import get_connection
|
from cqlengine.connection import get_connection
|
||||||
|
from cqlengine.connection import connection_manager
|
||||||
|
|
||||||
def create_keyspace(name):
|
def create_keyspace(name):
|
||||||
con = get_connection(None)
|
with connection_manager() as con:
|
||||||
cur = con.cursor()
|
con.execute("""CREATE KEYSPACE {}
|
||||||
cur.execute("""CREATE KEYSPACE {}
|
|
||||||
WITH strategy_class = 'SimpleStrategy'
|
WITH strategy_class = 'SimpleStrategy'
|
||||||
AND strategy_options:replication_factor=1;""".format(name))
|
AND strategy_options:replication_factor=1;""".format(name))
|
||||||
|
|
||||||
def delete_keyspace(name):
|
def delete_keyspace(name):
|
||||||
con = get_connection(None)
|
with connection_manager() as con:
|
||||||
cur = con.cursor()
|
con.execute("DROP KEYSPACE {}".format(name))
|
||||||
cur.execute("DROP KEYSPACE {}".format(name))
|
|
||||||
|
|
||||||
def create_column_family(model):
|
def create_column_family(model):
|
||||||
#construct query string
|
#construct query string
|
||||||
cf_name = model.column_family_name()
|
cf_name = model.column_family_name()
|
||||||
|
raw_cf_name = model.column_family_name(include_keyspace=False)
|
||||||
|
|
||||||
conn = get_connection(model.keyspace)
|
with connection_manager() as con:
|
||||||
cur = conn.cursor()
|
|
||||||
|
|
||||||
#check for an existing column family
|
#check for an existing column family
|
||||||
ks_info = conn.client.describe_keyspace(model.keyspace)
|
ks_info = con.con.client.describe_keyspace(model.keyspace)
|
||||||
if not any([cf_name == cf.name for cf in ks_info.cf_defs]):
|
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
|
||||||
qs = ['CREATE TABLE {}'.format(cf_name)]
|
qs = ['CREATE TABLE {}'.format(cf_name)]
|
||||||
|
|
||||||
#add column types
|
#add column types
|
||||||
@@ -39,7 +37,7 @@ def create_column_family(model):
|
|||||||
qs += ['({})'.format(', '.join(qtypes))]
|
qs += ['({})'.format(', '.join(qtypes))]
|
||||||
qs = ' '.join(qs)
|
qs = ' '.join(qs)
|
||||||
|
|
||||||
cur.execute(qs)
|
con.execute(qs)
|
||||||
|
|
||||||
indexes = [c for n,c in model._columns.items() if c.index]
|
indexes = [c for n,c in model._columns.items() if c.index]
|
||||||
if indexes:
|
if indexes:
|
||||||
@@ -50,15 +48,15 @@ def create_column_family(model):
|
|||||||
qs += ['ON {}'.format(cf_name)]
|
qs += ['ON {}'.format(cf_name)]
|
||||||
qs += ['({})'.format(column.db_field_name)]
|
qs += ['({})'.format(column.db_field_name)]
|
||||||
qs = ' '.join(qs)
|
qs = ' '.join(qs)
|
||||||
cur.execute(qs)
|
con.execute(qs)
|
||||||
|
|
||||||
|
|
||||||
def delete_column_family(model):
|
def delete_column_family(model):
|
||||||
#check that model exists
|
#check that model exists
|
||||||
cf_name = model.column_family_name()
|
cf_name = model.column_family_name()
|
||||||
conn = get_connection(model.keyspace)
|
raw_cf_name = model.column_family_name(include_keyspace=False)
|
||||||
ks_info = conn.client.describe_keyspace(model.keyspace)
|
with connection_manager() as con:
|
||||||
if any([cf_name == cf.name for cf in ks_info.cf_defs]):
|
ks_info = con.con.client.describe_keyspace(model.keyspace)
|
||||||
cur = conn.cursor()
|
if any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
|
||||||
cur.execute('drop table {};'.format(cf_name))
|
con.execute('drop table {};'.format(cf_name))
|
||||||
|
|
||||||
|
|||||||
@@ -24,10 +24,8 @@ class BaseModel(object):
|
|||||||
value_mngr = column.value_manager(self, column, values.get(name, None))
|
value_mngr = column.value_manager(self, column, values.get(name, None))
|
||||||
self._values[name] = value_mngr
|
self._values[name] = value_mngr
|
||||||
|
|
||||||
#TODO: note any absent fields so they're not deleted
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def column_family_name(cls):
|
def column_family_name(cls, include_keyspace=True):
|
||||||
"""
|
"""
|
||||||
Returns the column family name if it's been defined
|
Returns the column family name if it's been defined
|
||||||
otherwise, it creates it from the module and class name
|
otherwise, it creates it from the module and class name
|
||||||
@@ -38,7 +36,9 @@ class BaseModel(object):
|
|||||||
cf_name = cf_name.replace('.', '_')
|
cf_name = cf_name.replace('.', '_')
|
||||||
#trim to less than 48 characters or cassandra will complain
|
#trim to less than 48 characters or cassandra will complain
|
||||||
cf_name = cf_name[-48:]
|
cf_name = cf_name[-48:]
|
||||||
return cf_name.lower()
|
cf_name = cf_name.lower()
|
||||||
|
if not include_keyspace: return cf_name
|
||||||
|
return '{}.{}'.format(cls.keyspace, cf_name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pk(self):
|
def pk(self):
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from hashlib import md5
|
|||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from cqlengine.connection import get_connection
|
from cqlengine.connection import get_connection
|
||||||
|
from cqlengine.connection import connection_manager
|
||||||
from cqlengine.exceptions import CQLEngineException
|
from cqlengine.exceptions import CQLEngineException
|
||||||
|
|
||||||
#CQL 3 reference:
|
#CQL 3 reference:
|
||||||
@@ -220,7 +221,7 @@ class QuerySet(object):
|
|||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
#TODO: cache results
|
#TODO: cache results
|
||||||
if self._cursor is None:
|
if self._cursor is None:
|
||||||
conn = get_connection(self.model.keyspace)
|
conn = get_connection()
|
||||||
self._cursor = conn.cursor()
|
self._cursor = conn.cursor()
|
||||||
self._cursor.execute(self._select_query(), self._where_values())
|
self._cursor.execute(self._select_query(), self._where_values())
|
||||||
self._rowcount = self._cursor.rowcount
|
self._rowcount = self._cursor.rowcount
|
||||||
@@ -350,9 +351,8 @@ class QuerySet(object):
|
|||||||
qs += ['WHERE {}'.format(self._where_clause())]
|
qs += ['WHERE {}'.format(self._where_clause())]
|
||||||
qs = ' '.join(qs)
|
qs = ' '.join(qs)
|
||||||
|
|
||||||
con = get_connection(self.model.keyspace)
|
with connection_manager() as con:
|
||||||
cur = con.cursor()
|
cur = con.execute(qs, self._where_values())
|
||||||
cur.execute(qs, self._where_values())
|
|
||||||
return cur.fetchone()[0]
|
return cur.fetchone()[0]
|
||||||
|
|
||||||
def limit(self, v):
|
def limit(self, v):
|
||||||
@@ -430,11 +430,10 @@ class QuerySet(object):
|
|||||||
qs += ["({})".format(', '.join([':'+f for f in field_names]))]
|
qs += ["({})".format(', '.join([':'+f for f in field_names]))]
|
||||||
qs = ' '.join(qs)
|
qs = ' '.join(qs)
|
||||||
|
|
||||||
conn = get_connection(self.model.keyspace)
|
with connection_manager() as con:
|
||||||
cur = conn.cursor()
|
con.execute(qs, field_values)
|
||||||
cur.execute(qs, field_values)
|
|
||||||
|
|
||||||
#TODO: delete deleted / nulled columns
|
#delete deleted / nulled columns
|
||||||
deleted = [k for k,v in instance._values.items() if v.deleted]
|
deleted = [k for k,v in instance._values.items() if v.deleted]
|
||||||
if deleted:
|
if deleted:
|
||||||
del_fields = [self.model._columns[f] for f in deleted]
|
del_fields = [self.model._columns[f] for f in deleted]
|
||||||
@@ -448,9 +447,9 @@ class QuerySet(object):
|
|||||||
qs = ' '.join(qs)
|
qs = ' '.join(qs)
|
||||||
|
|
||||||
pk_dict = dict([(v.db_field_name, getattr(instance, k)) for k,v in pks.items()])
|
pk_dict = dict([(v.db_field_name, getattr(instance, k)) for k,v in pks.items()])
|
||||||
cur.execute(qs, pk_dict)
|
|
||||||
|
|
||||||
|
|
||||||
|
with connection_manager() as con:
|
||||||
|
con.execute(qs, pk_dict)
|
||||||
|
|
||||||
|
|
||||||
def create(self, **kwargs):
|
def create(self, **kwargs):
|
||||||
@@ -466,9 +465,8 @@ class QuerySet(object):
|
|||||||
qs += ['WHERE {}'.format(self._where_clause())]
|
qs += ['WHERE {}'.format(self._where_clause())]
|
||||||
qs = ' '.join(qs)
|
qs = ' '.join(qs)
|
||||||
|
|
||||||
con = get_connection(self.model.keyspace)
|
with connection_manager() as con:
|
||||||
cur = con.cursor()
|
con.execute(qs, self._where_values())
|
||||||
cur.execute(qs, self._where_values())
|
|
||||||
|
|
||||||
|
|
||||||
def delete_instance(self, instance):
|
def delete_instance(self, instance):
|
||||||
@@ -478,8 +476,7 @@ class QuerySet(object):
|
|||||||
qs += ['WHERE {0}=:{0}'.format(pk_name)]
|
qs += ['WHERE {0}=:{0}'.format(pk_name)]
|
||||||
qs = ' '.join(qs)
|
qs = ' '.join(qs)
|
||||||
|
|
||||||
conn = get_connection(self.model.keyspace)
|
with connection_manager() as con:
|
||||||
cur = conn.cursor()
|
con.execute(qs, {pk_name:instance.pk})
|
||||||
cur.execute(qs, {pk_name:instance.pk})
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,14 @@
|
|||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
from cqlengine import connection
|
||||||
|
|
||||||
class BaseCassEngTestCase(TestCase):
|
class BaseCassEngTestCase(TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
super(BaseCassEngTestCase, cls).setUpClass()
|
||||||
|
if not connection._hosts:
|
||||||
|
connection.setup(['localhost'])
|
||||||
|
|
||||||
def assertHasAttr(self, obj, attr):
|
def assertHasAttr(self, obj, attr):
|
||||||
self.assertTrue(hasattr(obj, attr),
|
self.assertTrue(hasattr(obj, attr),
|
||||||
"{} doesn't have attribute: {}".format(obj, attr))
|
"{} doesn't have attribute: {}".format(obj, attr))
|
||||||
|
|||||||
0
cqlengine/tests/management/__init__.py
Normal file
0
cqlengine/tests/management/__init__.py
Normal file
0
cqlengine/tests/management/test_management.py
Normal file
0
cqlengine/tests/management/test_management.py
Normal file
Reference in New Issue
Block a user