adding dynamic columns to models (which aren't working yet)
working on the QuerySet class adding additional tests around model saving and loading
This commit is contained in:
@@ -14,7 +14,7 @@ class BaseColumn(object):
|
|||||||
:param primary_key: bool flag, there can be only one primary key per doc
|
:param primary_key: bool flag, there can be only one primary key per doc
|
||||||
:param db_field: the fieldname this field will map to in the database
|
:param db_field: the fieldname this field will map to in the database
|
||||||
:param default: the default value, can be a value or a callable (no args)
|
:param default: the default value, can be a value or a callable (no args)
|
||||||
:param null: bool, is the field nullable?
|
:param null: boolean, is the field nullable?
|
||||||
"""
|
"""
|
||||||
self.primary_key = primary_key
|
self.primary_key = primary_key
|
||||||
self.db_field = db_field
|
self.db_field = db_field
|
||||||
|
|||||||
@@ -25,10 +25,6 @@ class BaseModel(object):
|
|||||||
if k not in values:
|
if k not in values:
|
||||||
setattr(self, k, None)
|
setattr(self, k, None)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _column_family_definition(cls):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find(cls, pk):
|
def find(cls, pk):
|
||||||
""" Loads a document by it's primary key """
|
""" Loads a document by it's primary key """
|
||||||
@@ -39,6 +35,16 @@ class BaseModel(object):
|
|||||||
""" Returns the object's primary key, regardless of it's name """
|
""" Returns the object's primary key, regardless of it's name """
|
||||||
return getattr(self, self._pk_name)
|
return getattr(self, self._pk_name)
|
||||||
|
|
||||||
|
#dynamic column methods
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._dynamic_columns[key]
|
||||||
|
|
||||||
|
def __setitem__(self, key, val):
|
||||||
|
self._dynamic_columns[key] = val
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
del self._dynamic_columns[key]
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
""" Cleans and validates the field values """
|
""" Cleans and validates the field values """
|
||||||
for name, col in self._columns.items():
|
for name, col in self._columns.items():
|
||||||
@@ -47,11 +53,9 @@ class BaseModel(object):
|
|||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
""" Returns a map of column names to cleaned values """
|
""" Returns a map of column names to cleaned values """
|
||||||
values = {}
|
values = self._dynamic_columns or {}
|
||||||
for name, col in self._columns.items():
|
for name, col in self._columns.items():
|
||||||
values[name] = col.to_database(getattr(self, name, None))
|
values[name] = col.to_database(getattr(self, name, None))
|
||||||
|
|
||||||
#TODO: merge in dynamic columns
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import copy
|
||||||
|
|
||||||
from cassandraengine.connection import get_connection
|
from cassandraengine.connection import get_connection
|
||||||
|
|
||||||
class QuerySet(object):
|
class QuerySet(object):
|
||||||
@@ -5,16 +7,18 @@ class QuerySet(object):
|
|||||||
#TODO: querysets should be executed lazily
|
#TODO: querysets should be executed lazily
|
||||||
#TODO: conflicting filter args should raise exception unless a force kwarg is supplied
|
#TODO: conflicting filter args should raise exception unless a force kwarg is supplied
|
||||||
|
|
||||||
def __init__(self, model, query={}):
|
def __init__(self, model, query_args={}):
|
||||||
super(QuerySet, self).__init__()
|
super(QuerySet, self).__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.query_args = query_args
|
||||||
self.column_family_name = self.model.objects.column_family_name
|
self.column_family_name = self.model.objects.column_family_name
|
||||||
|
|
||||||
self._cursor = None
|
self._cursor = None
|
||||||
|
|
||||||
#----query generation / execution----
|
#----query generation / execution----
|
||||||
def _execute_query(self):
|
def _execute_query(self):
|
||||||
pass
|
conn = get_connection()
|
||||||
|
self._cursor = conn.cursor()
|
||||||
|
|
||||||
def _generate_querystring(self):
|
def _generate_querystring(self):
|
||||||
pass
|
pass
|
||||||
@@ -27,39 +31,61 @@ class QuerySet(object):
|
|||||||
|
|
||||||
#----Reads------
|
#----Reads------
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
pass
|
if self._cursor is None:
|
||||||
|
self._execute_query()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _get_next(self):
|
||||||
|
"""
|
||||||
|
Gets the next cursor result
|
||||||
|
Returns a db_field->value dict
|
||||||
|
"""
|
||||||
|
cur = self._cursor
|
||||||
|
values = cur.fetchone()
|
||||||
|
if values is None: return None
|
||||||
|
names = [i[0] for i in cur.description]
|
||||||
|
value_dict = dict(zip(names, values))
|
||||||
|
return value_dict
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
pass
|
values = self._get_next()
|
||||||
|
if values is None: raise StopIteration
|
||||||
|
return values
|
||||||
|
|
||||||
def first(self):
|
def first(self):
|
||||||
conn = get_connection()
|
|
||||||
cur = conn.cursor()
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def all(self):
|
def all(self):
|
||||||
pass
|
return QuerySet(self.model)
|
||||||
|
|
||||||
def filter(self, **kwargs):
|
def filter(self, **kwargs):
|
||||||
pass
|
qargs = copy.deepcopy(self.query_args)
|
||||||
|
qargs.update(kwargs)
|
||||||
|
return QuerySet(self.model, query_args=qargs)
|
||||||
|
|
||||||
def exclude(self, **kwargs):
|
def exclude(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Need to invert the logic for all kwargs
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def count(self):
|
||||||
|
"""
|
||||||
|
Returns the number of rows matched by this query
|
||||||
|
"""
|
||||||
|
|
||||||
def find(self, pk):
|
def find(self, pk):
|
||||||
"""
|
"""
|
||||||
loads one document identified by it's primary key
|
loads one document identified by it's primary key
|
||||||
"""
|
"""
|
||||||
|
#TODO: make this a convenience wrapper of the filter method
|
||||||
qs = 'SELECT * FROM {column_family} WHERE {pk_name}=:{pk_name}'
|
qs = 'SELECT * FROM {column_family} WHERE {pk_name}=:{pk_name}'
|
||||||
qs = qs.format(column_family=self.column_family_name,
|
qs = qs.format(column_family=self.column_family_name,
|
||||||
pk_name=self.model._pk_name)
|
pk_name=self.model._pk_name)
|
||||||
conn = get_connection()
|
conn = get_connection()
|
||||||
cur = conn.cursor()
|
self._cursor = conn.cursor()
|
||||||
cur.execute(qs, {self.model._pk_name:pk})
|
self._cursor.execute(qs, {self.model._pk_name:pk})
|
||||||
values = cur.fetchone()
|
return self._get_next()
|
||||||
names = [i[0] for i in cur.description]
|
|
||||||
value_dict = dict(zip(names, values))
|
|
||||||
return value_dict
|
|
||||||
|
|
||||||
|
|
||||||
#----writes----
|
#----writes----
|
||||||
|
|||||||
16
cassandraengine/tests/columns/test_validation.py
Normal file
16
cassandraengine/tests/columns/test_validation.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
#tests the behavior of the column classes
|
||||||
|
|
||||||
|
from cassandraengine.tests.base import BaseCassEngTestCase
|
||||||
|
|
||||||
|
from cassandraengine.columns import BaseColumn
|
||||||
|
from cassandraengine.columns import Bytes
|
||||||
|
from cassandraengine.columns import Ascii
|
||||||
|
from cassandraengine.columns import Text
|
||||||
|
from cassandraengine.columns import Integer
|
||||||
|
from cassandraengine.columns import DateTime
|
||||||
|
from cassandraengine.columns import UUID
|
||||||
|
from cassandraengine.columns import Boolean
|
||||||
|
from cassandraengine.columns import Float
|
||||||
|
from cassandraengine.columns import Decimal
|
||||||
|
|
||||||
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from unittest import skip
|
||||||
from cassandraengine.tests.base import BaseCassEngTestCase
|
from cassandraengine.tests.base import BaseCassEngTestCase
|
||||||
|
|
||||||
from cassandraengine.models import Model
|
from cassandraengine.models import Model
|
||||||
@@ -7,20 +8,56 @@ class TestModel(Model):
|
|||||||
count = columns.Integer()
|
count = columns.Integer()
|
||||||
text = columns.Text()
|
text = columns.Text()
|
||||||
|
|
||||||
|
#class TestModel2(Model):
|
||||||
|
|
||||||
class TestModelIO(BaseCassEngTestCase):
|
class TestModelIO(BaseCassEngTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(TestModelIO, self).setUp()
|
super(TestModelIO, self).setUp()
|
||||||
TestModel.objects._create_column_family()
|
TestModel.objects._create_column_family()
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
super(TestModelIO, self).tearDown()
|
|
||||||
TestModel.objects._delete_column_family()
|
|
||||||
|
|
||||||
def test_model_save_and_load(self):
|
def test_model_save_and_load(self):
|
||||||
|
"""
|
||||||
|
Tests that models can be saved and retrieved
|
||||||
|
"""
|
||||||
tm = TestModel.objects.create(count=8, text='123456789')
|
tm = TestModel.objects.create(count=8, text='123456789')
|
||||||
tm2 = TestModel.objects.find(tm.pk)
|
tm2 = TestModel.objects.find(tm.pk)
|
||||||
|
|
||||||
for cname in tm._columns.keys():
|
for cname in tm._columns.keys():
|
||||||
self.assertEquals(getattr(tm, cname), getattr(tm2, cname))
|
self.assertEquals(getattr(tm, cname), getattr(tm2, cname))
|
||||||
|
|
||||||
|
def test_model_updating_works_properly(self):
|
||||||
|
"""
|
||||||
|
Tests that subsequent saves after initial model creation work
|
||||||
|
"""
|
||||||
|
tm = TestModel.objects.create(count=8, text='123456789')
|
||||||
|
|
||||||
|
tm.count = 100
|
||||||
|
tm.save()
|
||||||
|
|
||||||
|
tm2 = TestModel.objects.find(tm.pk)
|
||||||
|
self.assertEquals(tm.count, tm2.count)
|
||||||
|
|
||||||
|
def test_nullable_columns_are_saved_properly(self):
|
||||||
|
"""
|
||||||
|
Tests that nullable columns save without any trouble
|
||||||
|
"""
|
||||||
|
|
||||||
|
@skip
|
||||||
|
def test_dynamic_columns(self):
|
||||||
|
"""
|
||||||
|
Tests that items put into dynamic columns are saved and retrieved properly
|
||||||
|
|
||||||
|
Note: seems I've misunderstood how arbitrary column names work in Cassandra
|
||||||
|
skipping for now
|
||||||
|
"""
|
||||||
|
#TODO:Fix this
|
||||||
|
tm = TestModel(count=8, text='123456789')
|
||||||
|
tm['other'] = 'something'
|
||||||
|
tm['number'] = 5
|
||||||
|
tm.save()
|
||||||
|
|
||||||
|
tm2 = TestModel.objects.find(tm.pk)
|
||||||
|
self.assertEquals(tm['other'], tm2['other'])
|
||||||
|
self.assertEquals(tm['number'], tm2['number'])
|
||||||
|
|
||||||
|
|||||||
0
cassandraengine/tests/model/test_validation.py
Normal file
0
cassandraengine/tests/model/test_validation.py
Normal file
Reference in New Issue
Block a user