diff --git a/cqlengine/models.py b/cqlengine/models.py index d03a0041..2c011b8e 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -7,10 +7,12 @@ from cqlengine.query import ModelQuerySet, DMLQuery, AbstractQueryableColumn from cqlengine.query import DoesNotExist as _DoesNotExist from cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned + class ModelDefinitionException(ModelException): pass DEFAULT_KEYSPACE = 'cqlengine' + class hybrid_classmethod(object): """ Allows a method to behave as both a class method and @@ -44,7 +46,7 @@ class QuerySetDescriptor(object): """ :rtype: ModelQuerySet """ if model.__abstract__: raise CQLEngineException('cannot execute queries against abstract models') - return ModelQuerySet(model) + return model.__queryset__(model) def __call__(self, *args, **kwargs): """ @@ -140,6 +142,10 @@ class BaseModel(object): #the keyspace for this model __keyspace__ = None + # the queryset class used for this class + __queryset__ = ModelQuerySet + __dmlquery__ = DMLQuery + __read_repair_chance__ = 0.1 def __init__(self, **values): @@ -261,7 +267,7 @@ class BaseModel(object): def save(self): is_new = self.pk is None self.validate() - DMLQuery(self.__class__, self, batch=self._batch).save() + self.__dmlquery__(self.__class__, self, batch=self._batch).save() #reset the value managers for v in self._values.values(): @@ -272,7 +278,7 @@ class BaseModel(object): def delete(self): """ Deletes this instance """ - DMLQuery(self.__class__, self, batch=self._batch).delete() + self.__dmlquery__(self.__class__, self, batch=self._batch).delete() @classmethod def _class_batch(cls, batch): diff --git a/cqlengine/tests/model/test_class_construction.py b/cqlengine/tests/model/test_class_construction.py index f0344c6b..6653946b 100644 --- a/cqlengine/tests/model/test_class_construction.py +++ b/cqlengine/tests/model/test_class_construction.py @@ -1,5 +1,5 @@ from uuid import uuid4 -from cqlengine.query import QueryException +from cqlengine.query import QueryException, ModelQuerySet, DMLQuery from cqlengine.tests.base import BaseCassEngTestCase from cqlengine.exceptions import ModelException, CQLEngineException @@ -300,6 +300,42 @@ class TestAbstractModelClasses(BaseCassEngTestCase): delete_table(ConcreteModelWithCol) +class TestCustomQuerySet(BaseCassEngTestCase): + """ Tests overriding the default queryset class """ + + class TestException(Exception): pass + + def test_overriding_queryset(self): + + class QSet(ModelQuerySet): + def create(iself, **kwargs): + raise self.TestException + + class CQModel(Model): + __queryset__ = QSet + part = columns.UUID(primary_key=True) + data = columns.Text() + + with self.assertRaises(self.TestException): + CQModel.create(part=uuid4(), data='s') + + def test_overriding_dmlqueryset(self): + + class DMLQ(DMLQuery): + def save(iself): + raise self.TestException + + class CDQModel(Model): + __dmlquery__ = DMLQ + part = columns.UUID(primary_key=True) + data = columns.Text() + + with self.assertRaises(self.TestException): + CDQModel().save() + + + +