diff --git a/cqlengine/connection.py b/cqlengine/connection.py index 52bc5cde..9be79225 100644 --- a/cqlengine/connection.py +++ b/cqlengine/connection.py @@ -23,7 +23,7 @@ _username = None _password = None _max_connections = 10 -def setup(hosts, username=None, password=None, max_connections=10): +def setup(hosts, username=None, password=None, max_connections=10, default_keyspace=None): """ Records the hosts and connects to one of them @@ -36,7 +36,11 @@ def setup(hosts, username=None, password=None, max_connections=10): _username = username _password = password _max_connections = max_connections - + + if default_keyspace: + from cqlengine import models + models.DEFAULT_KEYSPACE = default_keyspace + for host in hosts: host = host.strip() host = host.split(':') diff --git a/cqlengine/management.py b/cqlengine/management.py index ba8b8145..5ef536cf 100644 --- a/cqlengine/management.py +++ b/cqlengine/management.py @@ -14,7 +14,8 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3, :param **replication_values: 1.2 only, additional values to ad to the replication data map """ with connection_manager() as con: - if name not in [k.name for k in con.con.client.describe_keyspaces()]: + if not any([name == k.name for k in con.con.client.describe_keyspaces()]): +# if name not in [k.name for k in con.con.client.describe_keyspaces()]: try: #Try the 1.1 method con.execute("""CREATE KEYSPACE {} @@ -50,11 +51,11 @@ def create_table(model, create_missing_keyspace=True): #create missing keyspace if create_missing_keyspace: - create_keyspace(model.keyspace) + create_keyspace(model._get_keyspace()) with connection_manager() as con: #check for an existing column family - ks_info = con.con.client.describe_keyspace(model.keyspace) + ks_info = con.con.client.describe_keyspace(model._get_keyspace()) if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]): qs = ['CREATE TABLE {}'.format(cf_name)] @@ -85,7 +86,7 @@ def create_table(model, create_missing_keyspace=True): raise #get existing index names, skip ones that already exist - ks_info = con.con.client.describe_keyspace(model.keyspace) + ks_info = con.con.client.describe_keyspace(model._get_keyspace()) cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name] idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else [] idx_names = filter(None, idx_names) diff --git a/cqlengine/models.py b/cqlengine/models.py index 9b6ec3b0..b9801bf0 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -8,6 +8,8 @@ from cqlengine.query import QuerySet, QueryException, DMLQuery class ModelDefinitionException(ModelException): pass +DEFAULT_KEYSPACE = 'cqlengine' + class hybrid_classmethod(object): """ Allows a method to behave as both a class method and @@ -37,7 +39,7 @@ class BaseModel(object): table_name = None #the keyspace for this model - keyspace = 'cqlengine' + keyspace = None read_repair_chance = 0.1 def __init__(self, **values): @@ -64,6 +66,11 @@ class BaseModel(object): pks = self._primary_keys.keys() return all([not self._values[k].changed for k in self._primary_keys]) + @classmethod + def _get_keyspace(cls): + """ Returns the manual keyspace, if set, otherwise the default keyspace """ + return cls.keyspace or DEFAULT_KEYSPACE + def __eq__(self, other): return self.as_dict() == other.as_dict() @@ -93,7 +100,7 @@ class BaseModel(object): cf_name = cf_name.lower() cf_name = re.sub(r'^_+', '', cf_name) if not include_keyspace: return cf_name - return '{}.{}'.format(cls.keyspace, cf_name) + return '{}.{}'.format(cls._get_keyspace(), cf_name) @property def pk(self): diff --git a/cqlengine/query.py b/cqlengine/query.py index 29ab1396..36909f1e 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -184,7 +184,6 @@ class QuerySet(object): def __init__(self, model): super(QuerySet, self).__init__() self.model = model - self.column_family_name = self.model.column_family_name() #Where clause filters self._where = [] @@ -210,6 +209,10 @@ class QuerySet(object): self._batch = None + @property + def column_family_name(self): + return self.model.column_family_name() + def __unicode__(self): return self._select_query() diff --git a/cqlengine/tests/base.py b/cqlengine/tests/base.py index d94ea69b..0b450cdd 100644 --- a/cqlengine/tests/base.py +++ b/cqlengine/tests/base.py @@ -7,7 +7,7 @@ class BaseCassEngTestCase(TestCase): def setUpClass(cls): super(BaseCassEngTestCase, cls).setUpClass() if not connection._hosts: - connection.setup(['localhost']) + connection.setup(['localhost'], default_keyspace='cqlengine_test') def assertHasAttr(self, obj, attr): self.assertTrue(hasattr(obj, attr),