adding user configurable keyspace
This commit is contained in:
@@ -23,7 +23,7 @@ _username = None
|
||||
_password = None
|
||||
_max_connections = 10
|
||||
|
||||
def setup(hosts, username=None, password=None, max_connections=10):
|
||||
def setup(hosts, username=None, password=None, max_connections=10, default_keyspace=None):
|
||||
"""
|
||||
Records the hosts and connects to one of them
|
||||
|
||||
@@ -36,7 +36,11 @@ def setup(hosts, username=None, password=None, max_connections=10):
|
||||
_username = username
|
||||
_password = password
|
||||
_max_connections = max_connections
|
||||
|
||||
|
||||
if default_keyspace:
|
||||
from cqlengine import models
|
||||
models.DEFAULT_KEYSPACE = default_keyspace
|
||||
|
||||
for host in hosts:
|
||||
host = host.strip()
|
||||
host = host.split(':')
|
||||
|
||||
@@ -14,7 +14,8 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3,
|
||||
:param **replication_values: 1.2 only, additional values to ad to the replication data map
|
||||
"""
|
||||
with connection_manager() as con:
|
||||
if name not in [k.name for k in con.con.client.describe_keyspaces()]:
|
||||
if not any([name == k.name for k in con.con.client.describe_keyspaces()]):
|
||||
# if name not in [k.name for k in con.con.client.describe_keyspaces()]:
|
||||
try:
|
||||
#Try the 1.1 method
|
||||
con.execute("""CREATE KEYSPACE {}
|
||||
@@ -50,11 +51,11 @@ def create_table(model, create_missing_keyspace=True):
|
||||
|
||||
#create missing keyspace
|
||||
if create_missing_keyspace:
|
||||
create_keyspace(model.keyspace)
|
||||
create_keyspace(model._get_keyspace())
|
||||
|
||||
with connection_manager() as con:
|
||||
#check for an existing column family
|
||||
ks_info = con.con.client.describe_keyspace(model.keyspace)
|
||||
ks_info = con.con.client.describe_keyspace(model._get_keyspace())
|
||||
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
|
||||
qs = ['CREATE TABLE {}'.format(cf_name)]
|
||||
|
||||
@@ -85,7 +86,7 @@ def create_table(model, create_missing_keyspace=True):
|
||||
raise
|
||||
|
||||
#get existing index names, skip ones that already exist
|
||||
ks_info = con.con.client.describe_keyspace(model.keyspace)
|
||||
ks_info = con.con.client.describe_keyspace(model._get_keyspace())
|
||||
cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name]
|
||||
idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else []
|
||||
idx_names = filter(None, idx_names)
|
||||
|
||||
@@ -8,6 +8,8 @@ from cqlengine.query import QuerySet, QueryException, DMLQuery
|
||||
|
||||
class ModelDefinitionException(ModelException): pass
|
||||
|
||||
DEFAULT_KEYSPACE = 'cqlengine'
|
||||
|
||||
class hybrid_classmethod(object):
|
||||
"""
|
||||
Allows a method to behave as both a class method and
|
||||
@@ -37,7 +39,7 @@ class BaseModel(object):
|
||||
table_name = None
|
||||
|
||||
#the keyspace for this model
|
||||
keyspace = 'cqlengine'
|
||||
keyspace = None
|
||||
read_repair_chance = 0.1
|
||||
|
||||
def __init__(self, **values):
|
||||
@@ -64,6 +66,11 @@ class BaseModel(object):
|
||||
pks = self._primary_keys.keys()
|
||||
return all([not self._values[k].changed for k in self._primary_keys])
|
||||
|
||||
@classmethod
|
||||
def _get_keyspace(cls):
|
||||
""" Returns the manual keyspace, if set, otherwise the default keyspace """
|
||||
return cls.keyspace or DEFAULT_KEYSPACE
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.as_dict() == other.as_dict()
|
||||
|
||||
@@ -93,7 +100,7 @@ class BaseModel(object):
|
||||
cf_name = cf_name.lower()
|
||||
cf_name = re.sub(r'^_+', '', cf_name)
|
||||
if not include_keyspace: return cf_name
|
||||
return '{}.{}'.format(cls.keyspace, cf_name)
|
||||
return '{}.{}'.format(cls._get_keyspace(), cf_name)
|
||||
|
||||
@property
|
||||
def pk(self):
|
||||
|
||||
@@ -184,7 +184,6 @@ class QuerySet(object):
|
||||
def __init__(self, model):
|
||||
super(QuerySet, self).__init__()
|
||||
self.model = model
|
||||
self.column_family_name = self.model.column_family_name()
|
||||
|
||||
#Where clause filters
|
||||
self._where = []
|
||||
@@ -210,6 +209,10 @@ class QuerySet(object):
|
||||
|
||||
self._batch = None
|
||||
|
||||
@property
|
||||
def column_family_name(self):
|
||||
return self.model.column_family_name()
|
||||
|
||||
def __unicode__(self):
|
||||
return self._select_query()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ class BaseCassEngTestCase(TestCase):
|
||||
def setUpClass(cls):
|
||||
super(BaseCassEngTestCase, cls).setUpClass()
|
||||
if not connection._hosts:
|
||||
connection.setup(['localhost'])
|
||||
connection.setup(['localhost'], default_keyspace='cqlengine_test')
|
||||
|
||||
def assertHasAttr(self, obj, attr):
|
||||
self.assertTrue(hasattr(obj, attr),
|
||||
|
||||
Reference in New Issue
Block a user