Add ContextQuery to allow switching the model keyspace easily

This commit is contained in:
Alan Boudreault
2016-07-27 12:03:46 -04:00
parent 957b6c474b
commit c023d0181e
3 changed files with 150 additions and 5 deletions

View File

@@ -21,7 +21,7 @@ import warnings
from cassandra import metadata
from cassandra.cqlengine import CQLEngineException
from cassandra.cqlengine import columns
from cassandra.cqlengine import columns, query
from cassandra.cqlengine.connection import execute, get_cluster
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.named import NamedTable
@@ -119,9 +119,9 @@ def _get_index_name_by_column(table, column_name):
return index_metadata.name
def sync_table(model):
def sync_table(model, keyspaces=None):
"""
Inspects the model and creates / updates the corresponding table and columns.
Inspects the model and creates / updates the corresponding table and columns for all keyspaces.
Any User Defined Types used in the table are implicitly synchronized.
@@ -135,6 +135,20 @@ def sync_table(model):
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
"""
if keyspaces:
if not isinstance(keyspaces, (list, tuple)):
raise ValueError('keyspaces must be a list or a tuple.')
for keyspace in keyspaces:
with query.ContextQuery(model, keyspace=keyspace) as m:
_sync_table(m)
else:
_sync_table(model)
def _sync_table(model):
if not _allow_schema_modification():
return
@@ -431,15 +445,27 @@ def _update_options(model):
return False
def drop_table(model):
def drop_table(model, keyspaces=None):
"""
Drops the table indicated by the model, if it exists.
Drops the table indicated by the model, if it exists, for all keyspaces.
**This function should be used with caution, especially in production environments.
Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).**
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
"""
if keyspaces:
if not isinstance(keyspaces, (list, tuple)):
raise ValueError('keyspaces must be a list or a tuple.')
for keyspace in keyspaces:
with query.ContextQuery(model, keyspace=keyspace) as m:
_drop_table(m)
else:
_drop_table(model)
def _drop_table(model):
if not _allow_schema_modification():
return

View File

@@ -259,6 +259,26 @@ class BatchQuery(object):
self.execute()
class ContextQuery(object):
def __init__(self, model, keyspace=None):
from cassandra.cqlengine import models
if not issubclass(model, models.Model):
raise CQLEngineException("Models must be derived from base Model.")
ks = keyspace if keyspace else model.__keyspace__
new_type = type(model.__name__, (model,), {'__keyspace__': ks})
self.model = new_type
def __enter__(self):
return self.model
def __exit__(self, exc_type, exc_val, exc_tb):
return
class AbstractQuerySet(object):
def __init__(self, model):

View File

@@ -0,0 +1,99 @@
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cassandra.cqlengine import columns
from cassandra.cqlengine.management import drop_keyspace, sync_table, create_keyspace_simple
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.query import ContextQuery
from tests.integration.cqlengine.base import BaseCassEngTestCase
class TestModel(Model):
__keyspace__ = 'ks1'
partition = columns.Integer(primary_key=True)
cluster = columns.Integer(primary_key=True)
count = columns.Integer()
text = columns.Text()
class ContextQueryTests(BaseCassEngTestCase):
KEYSPACES = ('ks1', 'ks2', 'ks3', 'ks4')
@classmethod
def setUpClass(cls):
super(ContextQueryTests, cls).setUpClass()
for ks in cls.KEYSPACES:
create_keyspace_simple(ks, 1)
sync_table(TestModel, keyspaces=cls.KEYSPACES)
@classmethod
def tearDownClass(cls):
super(ContextQueryTests, cls).tearDownClass()
for ks in cls.KEYSPACES:
drop_keyspace(ks)
def setUp(self):
super(ContextQueryTests, self).setUp()
for ks in self.KEYSPACES:
with ContextQuery(TestModel, keyspace=ks) as tm:
for obj in tm.all():
obj.delete()
def test_context_manager(self):
for ks in self.KEYSPACES:
with ContextQuery(TestModel, keyspace=ks) as tm:
self.assertEqual(tm.__keyspace__, ks)
def test_default_keyspace(self):
# model keyspace write/read
for i in range(5):
TestModel.objects.create(partition=i, cluster=i)
with ContextQuery(TestModel) as tm:
self.assertEqual(5, len(tm.objects.all()))
with ContextQuery(TestModel, keyspace='ks1') as tm:
self.assertEqual(5, len(tm.objects.all()))
for ks in self.KEYSPACES[1:]:
with ContextQuery(TestModel, keyspace=ks) as tm:
self.assertEqual(0, len(tm.objects.all()))
def test_context_keyspace(self):
for i in range(5):
with ContextQuery(TestModel, keyspace='ks4') as tm:
tm.objects.create(partition=i, cluster=i)
with ContextQuery(TestModel, keyspace='ks4') as tm:
self.assertEqual(5, len(tm.objects.all()))
self.assertEqual(0, len(TestModel.objects.all()))
for ks in self.KEYSPACES[:2]:
with ContextQuery(TestModel, keyspace=ks) as tm:
self.assertEqual(0, len(tm.objects.all()))
# simple data update
with ContextQuery(TestModel, keyspace='ks4') as tm:
obj = tm.objects.get(partition=1)
obj.update(count=42)
self.assertEqual(42, tm.objects.get(partition=1).count)