adding dynamic columns to models (which aren't working yet)

working on the QuerySet class
adding additional tests around model saving and loading
This commit is contained in:
Blake Eggleston
2012-11-10 21:12:47 -08:00
parent cfebd5a32d
commit b6042ac57a
6 changed files with 109 additions and 26 deletions

View File

@@ -14,7 +14,7 @@ class BaseColumn(object):
:param primary_key: bool flag, there can be only one primary key per doc :param primary_key: bool flag, there can be only one primary key per doc
:param db_field: the fieldname this field will map to in the database :param db_field: the fieldname this field will map to in the database
:param default: the default value, can be a value or a callable (no args) :param default: the default value, can be a value or a callable (no args)
:param null: bool, is the field nullable? :param null: boolean, is the field nullable?
""" """
self.primary_key = primary_key self.primary_key = primary_key
self.db_field = db_field self.db_field = db_field

View File

@@ -25,10 +25,6 @@ class BaseModel(object):
if k not in values: if k not in values:
setattr(self, k, None) setattr(self, k, None)
@classmethod
def _column_family_definition(cls):
pass
@classmethod @classmethod
def find(cls, pk): def find(cls, pk):
""" Loads a document by it's primary key """ """ Loads a document by it's primary key """
@@ -39,6 +35,16 @@ class BaseModel(object):
""" Returns the object's primary key, regardless of it's name """ """ Returns the object's primary key, regardless of it's name """
return getattr(self, self._pk_name) return getattr(self, self._pk_name)
#dynamic column methods
def __getitem__(self, key):
return self._dynamic_columns[key]
def __setitem__(self, key, val):
self._dynamic_columns[key] = val
def __delitem__(self, key):
del self._dynamic_columns[key]
def validate(self): def validate(self):
""" Cleans and validates the field values """ """ Cleans and validates the field values """
for name, col in self._columns.items(): for name, col in self._columns.items():
@@ -47,11 +53,9 @@ class BaseModel(object):
def as_dict(self): def as_dict(self):
""" Returns a map of column names to cleaned values """ """ Returns a map of column names to cleaned values """
values = {} values = self._dynamic_columns or {}
for name, col in self._columns.items(): for name, col in self._columns.items():
values[name] = col.to_database(getattr(self, name, None)) values[name] = col.to_database(getattr(self, name, None))
#TODO: merge in dynamic columns
return values return values
def save(self): def save(self):

View File

@@ -1,3 +1,5 @@
import copy
from cassandraengine.connection import get_connection from cassandraengine.connection import get_connection
class QuerySet(object): class QuerySet(object):
@@ -5,16 +7,18 @@ class QuerySet(object):
#TODO: querysets should be executed lazily #TODO: querysets should be executed lazily
#TODO: conflicting filter args should raise exception unless a force kwarg is supplied #TODO: conflicting filter args should raise exception unless a force kwarg is supplied
def __init__(self, model, query={}): def __init__(self, model, query_args={}):
super(QuerySet, self).__init__() super(QuerySet, self).__init__()
self.model = model self.model = model
self.query_args = query_args
self.column_family_name = self.model.objects.column_family_name self.column_family_name = self.model.objects.column_family_name
self._cursor = None self._cursor = None
#----query generation / execution---- #----query generation / execution----
def _execute_query(self): def _execute_query(self):
pass conn = get_connection()
self._cursor = conn.cursor()
def _generate_querystring(self): def _generate_querystring(self):
pass pass
@@ -27,39 +31,61 @@ class QuerySet(object):
#----Reads------ #----Reads------
def __iter__(self): def __iter__(self):
pass if self._cursor is None:
self._execute_query()
return self
def _get_next(self):
"""
Gets the next cursor result
Returns a db_field->value dict
"""
cur = self._cursor
values = cur.fetchone()
if values is None: return None
names = [i[0] for i in cur.description]
value_dict = dict(zip(names, values))
return value_dict
def next(self): def next(self):
pass values = self._get_next()
if values is None: raise StopIteration
return values
def first(self): def first(self):
conn = get_connection()
cur = conn.cursor()
pass pass
def all(self): def all(self):
pass return QuerySet(self.model)
def filter(self, **kwargs): def filter(self, **kwargs):
pass qargs = copy.deepcopy(self.query_args)
qargs.update(kwargs)
return QuerySet(self.model, query_args=qargs)
def exclude(self, **kwargs): def exclude(self, **kwargs):
"""
Need to invert the logic for all kwargs
"""
pass pass
def count(self):
"""
Returns the number of rows matched by this query
"""
def find(self, pk): def find(self, pk):
""" """
loads one document identified by it's primary key loads one document identified by it's primary key
""" """
#TODO: make this a convenience wrapper of the filter method
qs = 'SELECT * FROM {column_family} WHERE {pk_name}=:{pk_name}' qs = 'SELECT * FROM {column_family} WHERE {pk_name}=:{pk_name}'
qs = qs.format(column_family=self.column_family_name, qs = qs.format(column_family=self.column_family_name,
pk_name=self.model._pk_name) pk_name=self.model._pk_name)
conn = get_connection() conn = get_connection()
cur = conn.cursor() self._cursor = conn.cursor()
cur.execute(qs, {self.model._pk_name:pk}) self._cursor.execute(qs, {self.model._pk_name:pk})
values = cur.fetchone() return self._get_next()
names = [i[0] for i in cur.description]
value_dict = dict(zip(names, values))
return value_dict
#----writes---- #----writes----

View File

@@ -0,0 +1,16 @@
#tests the behavior of the column classes
from cassandraengine.tests.base import BaseCassEngTestCase
from cassandraengine.columns import BaseColumn
from cassandraengine.columns import Bytes
from cassandraengine.columns import Ascii
from cassandraengine.columns import Text
from cassandraengine.columns import Integer
from cassandraengine.columns import DateTime
from cassandraengine.columns import UUID
from cassandraengine.columns import Boolean
from cassandraengine.columns import Float
from cassandraengine.columns import Decimal

View File

@@ -1,3 +1,4 @@
from unittest import skip
from cassandraengine.tests.base import BaseCassEngTestCase from cassandraengine.tests.base import BaseCassEngTestCase
from cassandraengine.models import Model from cassandraengine.models import Model
@@ -7,20 +8,56 @@ class TestModel(Model):
count = columns.Integer() count = columns.Integer()
text = columns.Text() text = columns.Text()
#class TestModel2(Model):
class TestModelIO(BaseCassEngTestCase): class TestModelIO(BaseCassEngTestCase):
def setUp(self): def setUp(self):
super(TestModelIO, self).setUp() super(TestModelIO, self).setUp()
TestModel.objects._create_column_family() TestModel.objects._create_column_family()
def tearDown(self):
super(TestModelIO, self).tearDown()
TestModel.objects._delete_column_family()
def test_model_save_and_load(self): def test_model_save_and_load(self):
"""
Tests that models can be saved and retrieved
"""
tm = TestModel.objects.create(count=8, text='123456789') tm = TestModel.objects.create(count=8, text='123456789')
tm2 = TestModel.objects.find(tm.pk) tm2 = TestModel.objects.find(tm.pk)
for cname in tm._columns.keys(): for cname in tm._columns.keys():
self.assertEquals(getattr(tm, cname), getattr(tm2, cname)) self.assertEquals(getattr(tm, cname), getattr(tm2, cname))
def test_model_updating_works_properly(self):
"""
Tests that subsequent saves after initial model creation work
"""
tm = TestModel.objects.create(count=8, text='123456789')
tm.count = 100
tm.save()
tm2 = TestModel.objects.find(tm.pk)
self.assertEquals(tm.count, tm2.count)
def test_nullable_columns_are_saved_properly(self):
"""
Tests that nullable columns save without any trouble
"""
@skip
def test_dynamic_columns(self):
"""
Tests that items put into dynamic columns are saved and retrieved properly
Note: seems I've misunderstood how arbitrary column names work in Cassandra
skipping for now
"""
#TODO:Fix this
tm = TestModel(count=8, text='123456789')
tm['other'] = 'something'
tm['number'] = 5
tm.save()
tm2 = TestModel.objects.find(tm.pk)
self.assertEquals(tm['other'], tm2['other'])
self.assertEquals(tm['number'], tm2['number'])