restructuring some of the batch query internals, putting some unit tests around batch querying
This commit is contained in:
@@ -8,6 +8,22 @@ from cqlengine.query import QuerySet, QueryException, DMLQuery
|
||||
|
||||
class ModelDefinitionException(ModelException): pass
|
||||
|
||||
class hybrid_classmethod(object):
|
||||
"""
|
||||
Allows a method to behave as both a class method and
|
||||
normal instance method depending on how it's called
|
||||
"""
|
||||
|
||||
def __init__(self, clsmethod, instmethod):
|
||||
self.clsmethod = clsmethod
|
||||
self.instmethod = instmethod
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
return self.clsmethod.__get__(owner, owner)
|
||||
else:
|
||||
return self.instmethod.__get__(instance, owner)
|
||||
|
||||
class BaseModel(object):
|
||||
"""
|
||||
The base model class, don't inherit from this, inherit from Model, defined below
|
||||
@@ -35,6 +51,7 @@ class BaseModel(object):
|
||||
# a flag set by the deserializer to indicate
|
||||
# that update should be used when persisting changes
|
||||
self._is_persisted = False
|
||||
self._batch = None
|
||||
|
||||
def _can_update(self):
|
||||
"""
|
||||
@@ -112,13 +129,10 @@ class BaseModel(object):
|
||||
def get(cls, **kwargs):
|
||||
return cls.objects.get(**kwargs)
|
||||
|
||||
def save(self, batch_obj=None):
|
||||
def save(self):
|
||||
is_new = self.pk is None
|
||||
self.validate()
|
||||
if batch_obj:
|
||||
DMLQuery(self.__class__, self).batch(batch_obj).save()
|
||||
else:
|
||||
DMLQuery(self.__class__, self).save()
|
||||
DMLQuery(self.__class__, self, batch=self._batch).save()
|
||||
|
||||
#reset the value managers
|
||||
for v in self._values.values():
|
||||
@@ -130,13 +144,19 @@ class BaseModel(object):
|
||||
|
||||
def delete(self):
|
||||
""" Deletes this instance """
|
||||
DMLQuery(self.__class__, self).delete()
|
||||
DMLQuery(self.__class__, self, batch=self._batch).delete()
|
||||
|
||||
@classmethod
|
||||
def _class_batch(cls, batch):
|
||||
return cls.objects.batch(batch)
|
||||
|
||||
def _inst_batch(self, batch):
|
||||
self._batch = batch
|
||||
return self
|
||||
|
||||
batch = hybrid_classmethod(_class_batch, _inst_batch)
|
||||
|
||||
|
||||
def batch(self, batch_obj):
|
||||
"""
|
||||
Returns a batched DML query
|
||||
"""
|
||||
return DMLQuery(self.__class__, self).batch(batch_obj)
|
||||
|
||||
class ModelMetaClass(type):
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ class BatchQuery(object):
|
||||
Handles the batching of queries
|
||||
"""
|
||||
|
||||
def __init__(self, consistency=Consistency.ONE, timestamp=None):
|
||||
def __init__(self, consistency=None, timestamp=None):
|
||||
self.queries = []
|
||||
self.consistency = consistency
|
||||
if timestamp is not None and not isinstance(timestamp, datetime):
|
||||
@@ -149,18 +149,18 @@ class BatchQuery(object):
|
||||
self.queries.append((query, params))
|
||||
|
||||
def execute(self):
|
||||
query_list = []
|
||||
parameters = {}
|
||||
|
||||
opener = 'BEGIN BATCH USING CONSISTENCY {}'.format(self.consistency)
|
||||
opener = 'BEGIN BATCH'
|
||||
if self.consistency:
|
||||
opener += ' USING CONSISTENCY {}'.format(self.consistency)
|
||||
if self.timestamp:
|
||||
epoch = datetime(1970, 1, 1)
|
||||
ts = long((self.timestamp - epoch).total_seconds() * 1000)
|
||||
opener += ' TIMESTAMP {}'.format(ts)
|
||||
|
||||
query_list = [opener]
|
||||
parameters = {}
|
||||
for query, params in self.queries:
|
||||
query_list.append(query)
|
||||
query_list.append(' ' + query)
|
||||
parameters.update(params)
|
||||
|
||||
query_list.append('APPLY BATCH;')
|
||||
@@ -560,7 +560,7 @@ class QuerySet(object):
|
||||
return self._only_or_defer('defer', fields)
|
||||
|
||||
def create(self, **kwargs):
|
||||
return self.model(**kwargs).save(batch_obj=self._batch)
|
||||
return self.model(**kwargs).batch(self._batch).save()
|
||||
|
||||
#----delete---
|
||||
def delete(self, columns=[]):
|
||||
@@ -590,11 +590,11 @@ class DMLQuery(object):
|
||||
unlike the read query object, this is mutable
|
||||
"""
|
||||
|
||||
def __init__(self, model, instance=None):
|
||||
def __init__(self, model, instance=None, batch=None):
|
||||
self.model = model
|
||||
self.column_family_name = self.model.column_family_name()
|
||||
self.instance = instance
|
||||
self.batch = None
|
||||
self.batch = batch
|
||||
pass
|
||||
|
||||
def batch(self, batch_obj):
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
from unittest import skip
|
||||
from uuid import uuid4
|
||||
import random
|
||||
from cqlengine import Model, columns
|
||||
from cqlengine.management import delete_table, create_table
|
||||
from cqlengine.query import BatchQuery, Consistency
|
||||
from cqlengine.tests.base import BaseCassEngTestCase
|
||||
|
||||
class TestMultiKeyModel(Model):
|
||||
partition = columns.Integer(primary_key=True)
|
||||
cluster = columns.Integer(primary_key=True)
|
||||
count = columns.Integer(required=False)
|
||||
text = columns.Text(required=False)
|
||||
|
||||
class BatchQueryTests(BaseCassEngTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(BatchQueryTests, cls).setUpClass()
|
||||
delete_table(TestMultiKeyModel)
|
||||
create_table(TestMultiKeyModel)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super(BatchQueryTests, cls).tearDownClass()
|
||||
delete_table(TestMultiKeyModel)
|
||||
|
||||
def setUp(self):
|
||||
super(BatchQueryTests, self).setUp()
|
||||
self.pkey = 1
|
||||
for obj in TestMultiKeyModel.filter(partition=self.pkey):
|
||||
obj.delete()
|
||||
|
||||
|
||||
|
||||
def test_insert_success_case(self):
|
||||
|
||||
b = BatchQuery()
|
||||
inst = TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=2, count=3, text='4')
|
||||
|
||||
with self.assertRaises(TestMultiKeyModel.DoesNotExist):
|
||||
TestMultiKeyModel.get(partition=self.pkey, cluster=2)
|
||||
|
||||
b.execute()
|
||||
|
||||
TestMultiKeyModel.get(partition=self.pkey, cluster=2)
|
||||
|
||||
def test_update_success_case(self):
|
||||
|
||||
inst = TestMultiKeyModel.create(partition=self.pkey, cluster=2, count=3, text='4')
|
||||
|
||||
b = BatchQuery()
|
||||
|
||||
inst.count = 4
|
||||
inst.batch(b).save()
|
||||
|
||||
inst2 = TestMultiKeyModel.get(partition=self.pkey, cluster=2)
|
||||
assert inst2.count == 3
|
||||
|
||||
b.execute()
|
||||
|
||||
inst3 = TestMultiKeyModel.get(partition=self.pkey, cluster=2)
|
||||
assert inst3.count == 4
|
||||
|
||||
def test_delete_success_case(self):
|
||||
|
||||
inst = TestMultiKeyModel.create(partition=self.pkey, cluster=2, count=3, text='4')
|
||||
|
||||
b = BatchQuery()
|
||||
|
||||
inst.batch(b).delete()
|
||||
|
||||
TestMultiKeyModel.get(partition=self.pkey, cluster=2)
|
||||
|
||||
b.execute()
|
||||
|
||||
with self.assertRaises(TestMultiKeyModel.DoesNotExist):
|
||||
TestMultiKeyModel.get(partition=self.pkey, cluster=2)
|
||||
|
||||
def test_context_manager(self):
|
||||
|
||||
with BatchQuery() as b:
|
||||
for i in range(5):
|
||||
TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=i, count=3, text='4')
|
||||
|
||||
for i in range(5):
|
||||
with self.assertRaises(TestMultiKeyModel.DoesNotExist):
|
||||
TestMultiKeyModel.get(partition=self.pkey, cluster=i)
|
||||
|
||||
for i in range(5):
|
||||
TestMultiKeyModel.get(partition=self.pkey, cluster=i)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user