Add ContextQuery to allow switching the model keyspace easily
This commit is contained in:
@@ -21,7 +21,7 @@ import warnings
|
||||
|
||||
from cassandra import metadata
|
||||
from cassandra.cqlengine import CQLEngineException
|
||||
from cassandra.cqlengine import columns
|
||||
from cassandra.cqlengine import columns, query
|
||||
from cassandra.cqlengine.connection import execute, get_cluster
|
||||
from cassandra.cqlengine.models import Model
|
||||
from cassandra.cqlengine.named import NamedTable
|
||||
@@ -119,9 +119,9 @@ def _get_index_name_by_column(table, column_name):
|
||||
return index_metadata.name
|
||||
|
||||
|
||||
def sync_table(model):
|
||||
def sync_table(model, keyspaces=None):
|
||||
"""
|
||||
Inspects the model and creates / updates the corresponding table and columns.
|
||||
Inspects the model and creates / updates the corresponding table and columns for all keyspaces.
|
||||
|
||||
Any User Defined Types used in the table are implicitly synchronized.
|
||||
|
||||
@@ -135,6 +135,20 @@ def sync_table(model):
|
||||
|
||||
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
|
||||
"""
|
||||
|
||||
if keyspaces:
|
||||
if not isinstance(keyspaces, (list, tuple)):
|
||||
raise ValueError('keyspaces must be a list or a tuple.')
|
||||
|
||||
for keyspace in keyspaces:
|
||||
with query.ContextQuery(model, keyspace=keyspace) as m:
|
||||
_sync_table(m)
|
||||
else:
|
||||
_sync_table(model)
|
||||
|
||||
|
||||
def _sync_table(model):
|
||||
|
||||
if not _allow_schema_modification():
|
||||
return
|
||||
|
||||
@@ -431,15 +445,27 @@ def _update_options(model):
|
||||
return False
|
||||
|
||||
|
||||
def drop_table(model):
|
||||
def drop_table(model, keyspaces=None):
|
||||
"""
|
||||
Drops the table indicated by the model, if it exists.
|
||||
Drops the table indicated by the model, if it exists, for all keyspaces.
|
||||
|
||||
**This function should be used with caution, especially in production environments.
|
||||
Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).**
|
||||
|
||||
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
|
||||
"""
|
||||
|
||||
if keyspaces:
|
||||
if not isinstance(keyspaces, (list, tuple)):
|
||||
raise ValueError('keyspaces must be a list or a tuple.')
|
||||
|
||||
for keyspace in keyspaces:
|
||||
with query.ContextQuery(model, keyspace=keyspace) as m:
|
||||
_drop_table(m)
|
||||
else:
|
||||
_drop_table(model)
|
||||
|
||||
def _drop_table(model):
|
||||
if not _allow_schema_modification():
|
||||
return
|
||||
|
||||
|
||||
@@ -259,6 +259,26 @@ class BatchQuery(object):
|
||||
self.execute()
|
||||
|
||||
|
||||
class ContextQuery(object):
|
||||
|
||||
def __init__(self, model, keyspace=None):
|
||||
from cassandra.cqlengine import models
|
||||
|
||||
if not issubclass(model, models.Model):
|
||||
raise CQLEngineException("Models must be derived from base Model.")
|
||||
|
||||
ks = keyspace if keyspace else model.__keyspace__
|
||||
new_type = type(model.__name__, (model,), {'__keyspace__': ks})
|
||||
|
||||
self.model = new_type
|
||||
|
||||
def __enter__(self):
|
||||
return self.model
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return
|
||||
|
||||
|
||||
class AbstractQuerySet(object):
|
||||
|
||||
def __init__(self, model):
|
||||
|
||||
99
tests/integration/cqlengine/test_context_query.py
Normal file
99
tests/integration/cqlengine/test_context_query.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright 2013-2016 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from cassandra.cqlengine import columns
|
||||
from cassandra.cqlengine.management import drop_keyspace, sync_table, create_keyspace_simple
|
||||
from cassandra.cqlengine.models import Model
|
||||
from cassandra.cqlengine.query import ContextQuery
|
||||
from tests.integration.cqlengine.base import BaseCassEngTestCase
|
||||
|
||||
|
||||
class TestModel(Model):
|
||||
|
||||
__keyspace__ = 'ks1'
|
||||
|
||||
partition = columns.Integer(primary_key=True)
|
||||
cluster = columns.Integer(primary_key=True)
|
||||
count = columns.Integer()
|
||||
text = columns.Text()
|
||||
|
||||
|
||||
class ContextQueryTests(BaseCassEngTestCase):
|
||||
|
||||
KEYSPACES = ('ks1', 'ks2', 'ks3', 'ks4')
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(ContextQueryTests, cls).setUpClass()
|
||||
for ks in cls.KEYSPACES:
|
||||
create_keyspace_simple(ks, 1)
|
||||
sync_table(TestModel, keyspaces=cls.KEYSPACES)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super(ContextQueryTests, cls).tearDownClass()
|
||||
for ks in cls.KEYSPACES:
|
||||
drop_keyspace(ks)
|
||||
|
||||
def setUp(self):
|
||||
super(ContextQueryTests, self).setUp()
|
||||
for ks in self.KEYSPACES:
|
||||
with ContextQuery(TestModel, keyspace=ks) as tm:
|
||||
for obj in tm.all():
|
||||
obj.delete()
|
||||
|
||||
def test_context_manager(self):
|
||||
|
||||
for ks in self.KEYSPACES:
|
||||
with ContextQuery(TestModel, keyspace=ks) as tm:
|
||||
self.assertEqual(tm.__keyspace__, ks)
|
||||
|
||||
def test_default_keyspace(self):
|
||||
|
||||
# model keyspace write/read
|
||||
for i in range(5):
|
||||
TestModel.objects.create(partition=i, cluster=i)
|
||||
|
||||
with ContextQuery(TestModel) as tm:
|
||||
self.assertEqual(5, len(tm.objects.all()))
|
||||
|
||||
with ContextQuery(TestModel, keyspace='ks1') as tm:
|
||||
self.assertEqual(5, len(tm.objects.all()))
|
||||
|
||||
for ks in self.KEYSPACES[1:]:
|
||||
with ContextQuery(TestModel, keyspace=ks) as tm:
|
||||
self.assertEqual(0, len(tm.objects.all()))
|
||||
|
||||
def test_context_keyspace(self):
|
||||
|
||||
for i in range(5):
|
||||
with ContextQuery(TestModel, keyspace='ks4') as tm:
|
||||
tm.objects.create(partition=i, cluster=i)
|
||||
|
||||
with ContextQuery(TestModel, keyspace='ks4') as tm:
|
||||
self.assertEqual(5, len(tm.objects.all()))
|
||||
|
||||
self.assertEqual(0, len(TestModel.objects.all()))
|
||||
|
||||
for ks in self.KEYSPACES[:2]:
|
||||
with ContextQuery(TestModel, keyspace=ks) as tm:
|
||||
self.assertEqual(0, len(tm.objects.all()))
|
||||
|
||||
# simple data update
|
||||
with ContextQuery(TestModel, keyspace='ks4') as tm:
|
||||
obj = tm.objects.get(partition=1)
|
||||
obj.update(count=42)
|
||||
|
||||
self.assertEqual(42, tm.objects.get(partition=1).count)
|
||||
|
||||
Reference in New Issue
Block a user