added fault tolerant connection manager, which will be easily extended to support connection pooling
added support for multiple connections removed dependence on setting keyspace name on opening connection
This commit is contained in:
@@ -2,17 +2,67 @@
|
||||
#http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2 /
|
||||
#http://cassandra.apache.org/doc/cql/CQL.html
|
||||
|
||||
from collections import namedtuple
|
||||
import random
|
||||
|
||||
import cql
|
||||
|
||||
from cqlengine.exceptions import CQLEngineException
|
||||
|
||||
from thrift.transport.TTransport import TTransportException
|
||||
|
||||
|
||||
class CQLConnectionError(CQLEngineException): pass
|
||||
|
||||
_keyspace = 'cassengine_test'
|
||||
Host = namedtuple('Host', ['name', 'port'])
|
||||
_hosts = []
|
||||
_host_idx = 0
|
||||
_conn= None
|
||||
_username = None
|
||||
_password = None
|
||||
|
||||
def _set_conn(host):
|
||||
"""
|
||||
"""
|
||||
global _conn
|
||||
_conn = cql.connect(host.name, host.port, user=_username, password=_password)
|
||||
_conn.set_cql_version('3.0.0')
|
||||
|
||||
def setup(hosts, username=None, password=None):
|
||||
"""
|
||||
Records the hosts and connects to one of them
|
||||
|
||||
:param hosts: list of hosts, strings in the <hostname>:<port>, or just <hostname>
|
||||
"""
|
||||
global _hosts
|
||||
global _username
|
||||
global _password
|
||||
|
||||
_username = username
|
||||
_password = password
|
||||
|
||||
for host in hosts:
|
||||
host = host.strip()
|
||||
host = host.split(':')
|
||||
if len(host) == 1:
|
||||
_hosts.append(Host(host[0], 9160))
|
||||
elif len(host) == 2:
|
||||
_hosts.append(Host(*host))
|
||||
else:
|
||||
raise CQLConnectionError("Can't parse {}".format(''.join(host)))
|
||||
|
||||
if not _hosts:
|
||||
raise CQLConnectionError("At least one host required")
|
||||
|
||||
random.shuffle(_hosts)
|
||||
host = _hosts[_host_idx]
|
||||
_set_conn(host)
|
||||
|
||||
|
||||
#TODO: look into the cql connection pool class
|
||||
_conn = {}
|
||||
def get_connection(keyspace, create_missing_keyspace=True):
|
||||
con = _conn.get(keyspace)
|
||||
_old_conn = {}
|
||||
def get_connection(keyspace=None, create_missing_keyspace=True):
|
||||
con = _old_conn.get(keyspace)
|
||||
if con is None:
|
||||
con = cql.connect('127.0.0.1', 9160)
|
||||
con.set_cql_version('3.0.0')
|
||||
@@ -27,7 +77,50 @@ def get_connection(keyspace, create_missing_keyspace=True):
|
||||
else:
|
||||
raise CQLConnectionError('"{}" is not an existing keyspace'.format(keyspace))
|
||||
|
||||
_conn[keyspace] = con
|
||||
_old_conn[keyspace] = con
|
||||
|
||||
return con
|
||||
|
||||
class connection_manager(object):
|
||||
"""
|
||||
Connection failure tolerant connection manager. Written to be used in a 'with' block for connection pooling
|
||||
"""
|
||||
def __init__(self):
|
||||
if not _hosts:
|
||||
raise CQLConnectionError("No connections have been configured, call cqlengine.connection.setup")
|
||||
self.keyspace = None
|
||||
self.con = _conn
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
pass
|
||||
|
||||
def execute(self, query, params={}):
|
||||
"""
|
||||
Gets a connection from the pool and executes the given query, returns the cursor
|
||||
|
||||
if there's a connection problem, this will silently create a new connection pool
|
||||
from the available hosts, and remove the problematic host from the host list
|
||||
"""
|
||||
global _host_idx
|
||||
|
||||
for i in range(len(_hosts)):
|
||||
try:
|
||||
cur = self.con.cursor()
|
||||
cur.execute(query, params)
|
||||
return cur
|
||||
except TTransportException:
|
||||
#TODO: check for other errors raised in the event of a connection / server problem
|
||||
#move to the next connection and set the connection pool
|
||||
self.con_pool.return_connection(self.con)
|
||||
self.con = None
|
||||
_host_idx += 1
|
||||
_host_idx %= len(_hosts)
|
||||
host = _hosts[_host_idx]
|
||||
_set_conn(host)
|
||||
self.con = _conn
|
||||
|
||||
raise CQLConnectionError("couldn't reach a Cassandra server")
|
||||
|
||||
|
||||
@@ -1,64 +1,62 @@
|
||||
from cqlengine.connection import get_connection
|
||||
from cqlengine.connection import connection_manager
|
||||
|
||||
def create_keyspace(name):
|
||||
con = get_connection(None)
|
||||
cur = con.cursor()
|
||||
cur.execute("""CREATE KEYSPACE {}
|
||||
WITH strategy_class = 'SimpleStrategy'
|
||||
AND strategy_options:replication_factor=1;""".format(name))
|
||||
with connection_manager() as con:
|
||||
con.execute("""CREATE KEYSPACE {}
|
||||
WITH strategy_class = 'SimpleStrategy'
|
||||
AND strategy_options:replication_factor=1;""".format(name))
|
||||
|
||||
def delete_keyspace(name):
|
||||
con = get_connection(None)
|
||||
cur = con.cursor()
|
||||
cur.execute("DROP KEYSPACE {}".format(name))
|
||||
with connection_manager() as con:
|
||||
con.execute("DROP KEYSPACE {}".format(name))
|
||||
|
||||
def create_column_family(model):
|
||||
#construct query string
|
||||
cf_name = model.column_family_name()
|
||||
raw_cf_name = model.column_family_name(include_keyspace=False)
|
||||
|
||||
conn = get_connection(model.keyspace)
|
||||
cur = conn.cursor()
|
||||
with connection_manager() as con:
|
||||
#check for an existing column family
|
||||
ks_info = con.con.client.describe_keyspace(model.keyspace)
|
||||
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
|
||||
qs = ['CREATE TABLE {}'.format(cf_name)]
|
||||
|
||||
#check for an existing column family
|
||||
ks_info = conn.client.describe_keyspace(model.keyspace)
|
||||
if not any([cf_name == cf.name for cf in ks_info.cf_defs]):
|
||||
qs = ['CREATE TABLE {}'.format(cf_name)]
|
||||
#add column types
|
||||
pkeys = []
|
||||
qtypes = []
|
||||
def add_column(col):
|
||||
s = '{} {}'.format(col.db_field_name, col.db_type)
|
||||
if col.primary_key: pkeys.append(col.db_field_name)
|
||||
qtypes.append(s)
|
||||
for name, col in model._columns.items():
|
||||
add_column(col)
|
||||
|
||||
#add column types
|
||||
pkeys = []
|
||||
qtypes = []
|
||||
def add_column(col):
|
||||
s = '{} {}'.format(col.db_field_name, col.db_type)
|
||||
if col.primary_key: pkeys.append(col.db_field_name)
|
||||
qtypes.append(s)
|
||||
for name, col in model._columns.items():
|
||||
add_column(col)
|
||||
qtypes.append('PRIMARY KEY ({})'.format(', '.join(pkeys)))
|
||||
|
||||
qtypes.append('PRIMARY KEY ({})'.format(', '.join(pkeys)))
|
||||
|
||||
qs += ['({})'.format(', '.join(qtypes))]
|
||||
qs = ' '.join(qs)
|
||||
|
||||
cur.execute(qs)
|
||||
|
||||
indexes = [c for n,c in model._columns.items() if c.index]
|
||||
if indexes:
|
||||
for column in indexes:
|
||||
#TODO: check for existing index...
|
||||
#can that be determined from the connection client?
|
||||
qs = ['CREATE INDEX {}'.format(column.db_index_name)]
|
||||
qs += ['ON {}'.format(cf_name)]
|
||||
qs += ['({})'.format(column.db_field_name)]
|
||||
qs += ['({})'.format(', '.join(qtypes))]
|
||||
qs = ' '.join(qs)
|
||||
cur.execute(qs)
|
||||
|
||||
con.execute(qs)
|
||||
|
||||
indexes = [c for n,c in model._columns.items() if c.index]
|
||||
if indexes:
|
||||
for column in indexes:
|
||||
#TODO: check for existing index...
|
||||
#can that be determined from the connection client?
|
||||
qs = ['CREATE INDEX {}'.format(column.db_index_name)]
|
||||
qs += ['ON {}'.format(cf_name)]
|
||||
qs += ['({})'.format(column.db_field_name)]
|
||||
qs = ' '.join(qs)
|
||||
con.execute(qs)
|
||||
|
||||
|
||||
def delete_column_family(model):
|
||||
#check that model exists
|
||||
cf_name = model.column_family_name()
|
||||
conn = get_connection(model.keyspace)
|
||||
ks_info = conn.client.describe_keyspace(model.keyspace)
|
||||
if any([cf_name == cf.name for cf in ks_info.cf_defs]):
|
||||
cur = conn.cursor()
|
||||
cur.execute('drop table {};'.format(cf_name))
|
||||
raw_cf_name = model.column_family_name(include_keyspace=False)
|
||||
with connection_manager() as con:
|
||||
ks_info = con.con.client.describe_keyspace(model.keyspace)
|
||||
if any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
|
||||
con.execute('drop table {};'.format(cf_name))
|
||||
|
||||
|
||||
@@ -24,10 +24,8 @@ class BaseModel(object):
|
||||
value_mngr = column.value_manager(self, column, values.get(name, None))
|
||||
self._values[name] = value_mngr
|
||||
|
||||
#TODO: note any absent fields so they're not deleted
|
||||
|
||||
@classmethod
|
||||
def column_family_name(cls):
|
||||
def column_family_name(cls, include_keyspace=True):
|
||||
"""
|
||||
Returns the column family name if it's been defined
|
||||
otherwise, it creates it from the module and class name
|
||||
@@ -38,7 +36,9 @@ class BaseModel(object):
|
||||
cf_name = cf_name.replace('.', '_')
|
||||
#trim to less than 48 characters or cassandra will complain
|
||||
cf_name = cf_name[-48:]
|
||||
return cf_name.lower()
|
||||
cf_name = cf_name.lower()
|
||||
if not include_keyspace: return cf_name
|
||||
return '{}.{}'.format(cls.keyspace, cf_name)
|
||||
|
||||
@property
|
||||
def pk(self):
|
||||
|
||||
@@ -4,6 +4,7 @@ from hashlib import md5
|
||||
from time import time
|
||||
|
||||
from cqlengine.connection import get_connection
|
||||
from cqlengine.connection import connection_manager
|
||||
from cqlengine.exceptions import CQLEngineException
|
||||
|
||||
#CQL 3 reference:
|
||||
@@ -220,7 +221,7 @@ class QuerySet(object):
|
||||
def __iter__(self):
|
||||
#TODO: cache results
|
||||
if self._cursor is None:
|
||||
conn = get_connection(self.model.keyspace)
|
||||
conn = get_connection()
|
||||
self._cursor = conn.cursor()
|
||||
self._cursor.execute(self._select_query(), self._where_values())
|
||||
self._rowcount = self._cursor.rowcount
|
||||
@@ -350,10 +351,9 @@ class QuerySet(object):
|
||||
qs += ['WHERE {}'.format(self._where_clause())]
|
||||
qs = ' '.join(qs)
|
||||
|
||||
con = get_connection(self.model.keyspace)
|
||||
cur = con.cursor()
|
||||
cur.execute(qs, self._where_values())
|
||||
return cur.fetchone()[0]
|
||||
with connection_manager() as con:
|
||||
cur = con.execute(qs, self._where_values())
|
||||
return cur.fetchone()[0]
|
||||
|
||||
def limit(self, v):
|
||||
"""
|
||||
@@ -430,11 +430,10 @@ class QuerySet(object):
|
||||
qs += ["({})".format(', '.join([':'+f for f in field_names]))]
|
||||
qs = ' '.join(qs)
|
||||
|
||||
conn = get_connection(self.model.keyspace)
|
||||
cur = conn.cursor()
|
||||
cur.execute(qs, field_values)
|
||||
with connection_manager() as con:
|
||||
con.execute(qs, field_values)
|
||||
|
||||
#TODO: delete deleted / nulled columns
|
||||
#delete deleted / nulled columns
|
||||
deleted = [k for k,v in instance._values.items() if v.deleted]
|
||||
if deleted:
|
||||
del_fields = [self.model._columns[f] for f in deleted]
|
||||
@@ -448,11 +447,11 @@ class QuerySet(object):
|
||||
qs = ' '.join(qs)
|
||||
|
||||
pk_dict = dict([(v.db_field_name, getattr(instance, k)) for k,v in pks.items()])
|
||||
cur.execute(qs, pk_dict)
|
||||
|
||||
with connection_manager() as con:
|
||||
con.execute(qs, pk_dict)
|
||||
|
||||
|
||||
|
||||
|
||||
def create(self, **kwargs):
|
||||
return self.model(**kwargs).save()
|
||||
|
||||
@@ -466,9 +465,8 @@ class QuerySet(object):
|
||||
qs += ['WHERE {}'.format(self._where_clause())]
|
||||
qs = ' '.join(qs)
|
||||
|
||||
con = get_connection(self.model.keyspace)
|
||||
cur = con.cursor()
|
||||
cur.execute(qs, self._where_values())
|
||||
with connection_manager() as con:
|
||||
con.execute(qs, self._where_values())
|
||||
|
||||
|
||||
def delete_instance(self, instance):
|
||||
@@ -478,8 +476,7 @@ class QuerySet(object):
|
||||
qs += ['WHERE {0}=:{0}'.format(pk_name)]
|
||||
qs = ' '.join(qs)
|
||||
|
||||
conn = get_connection(self.model.keyspace)
|
||||
cur = conn.cursor()
|
||||
cur.execute(qs, {pk_name:instance.pk})
|
||||
with connection_manager() as con:
|
||||
con.execute(qs, {pk_name:instance.pk})
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
from unittest import TestCase
|
||||
from cqlengine import connection
|
||||
|
||||
class BaseCassEngTestCase(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(BaseCassEngTestCase, cls).setUpClass()
|
||||
if not connection._hosts:
|
||||
connection.setup(['localhost'])
|
||||
|
||||
def assertHasAttr(self, obj, attr):
|
||||
self.assertTrue(hasattr(obj, attr),
|
||||
"{} doesn't have attribute: {}".format(obj, attr))
|
||||
|
||||
0
cqlengine/tests/management/__init__.py
Normal file
0
cqlengine/tests/management/__init__.py
Normal file
0
cqlengine/tests/management/test_management.py
Normal file
0
cqlengine/tests/management/test_management.py
Normal file
Reference in New Issue
Block a user