writing query delete method

added queryset usage tests
refactoring how column names are stored and calculated
This commit is contained in:
Blake Eggleston
2012-11-22 23:16:49 -08:00
parent 1fcf809095
commit b7d91fe869
4 changed files with 111 additions and 32 deletions

View File

@@ -56,6 +56,9 @@ class BaseColumn(object):
self.default = default
self.null = null
#the column name in the model definition
self.column_name = None
self.value = None
#keep track of instantiation order
@@ -112,17 +115,21 @@ class BaseColumn(object):
"""
Returns a column definition for CQL table definition
"""
dterms = [self.db_field, self.db_type]
dterms = [self.db_field_name, self.db_type]
#if self.primary_key:
#dterms.append('PRIMARY KEY')
return ' '.join(dterms)
def set_db_name(self, name):
def set_column_name(self, name):
"""
Sets the column name during document class construction
This value will be ignored if db_field is set in __init__
"""
self.db_field = self.db_field or name
self.column_name = name
@property
def db_field_name(self):
return self.db_field or self.column_name
class Bytes(BaseColumn):
db_type = 'blob'

View File

@@ -62,13 +62,11 @@ class BaseModel(object):
def save(self):
is_new = self.pk is None
self.validate()
#self.objects._save_instance(self)
self.objects.save(self)
return self
def delete(self):
""" Deletes this instance """
#self.objects._delete_instance(self)
self.objects.delete_instance(self)
@@ -84,7 +82,7 @@ class ModelMetaClass(type):
def _transform_column(col_name, col_obj):
_columns[col_name] = col_obj
col_obj.set_db_name(col_name)
col_obj.set_column_name(col_name)
#set properties
_get = lambda self: self._values[col_name].getval()
_set = lambda self, val: self._values[col_name].setval(val)
@@ -94,7 +92,6 @@ class ModelMetaClass(type):
else:
attrs[col_name] = property(_get, _set, _del)
column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.BaseColumn)]
column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position))
@@ -116,9 +113,9 @@ class ModelMetaClass(type):
#check for duplicate column names
col_names = set()
for k,v in _columns.items():
if v.db_field in col_names:
raise ModelException("{} defines the column {} more than once".format(name, v.db_field))
col_names.add(v.db_field)
if v.db_field_name in col_names:
raise ModelException("{} defines the column {} more than once".format(name, v.db_field_name))
col_names.add(v.db_field_name)
#get column family name
cf_name = attrs.pop('db_name', name)
@@ -126,7 +123,7 @@ class ModelMetaClass(type):
#create db_name -> model name map for loading
db_map = {}
for name, col in _columns.items():
db_map[col.db_field] = name
db_map[col.db_field_name] = name
#add management members to the class
attrs['_columns'] = _columns

View File

@@ -39,7 +39,7 @@ class QueryOperator(object):
Returns this operator's portion of the WHERE clause
:param valname: the dict key that this operator's compare value will be found in
"""
return '{} {} :{}'.format(self.column.db_field, self.cql_symbol, self.identifier)
return '{} {} :{}'.format(self.column.db_field_name, self.cql_symbol, self.identifier)
def validate_operator(self):
"""
@@ -140,6 +140,12 @@ class QuerySet(object):
self._cursor = None
def __unicode__(self):
return self._select_query()
def __str__(self):
return str(self.__unicode__())
#----query generation / execution----
def _validate_where_syntax(self):
@@ -169,12 +175,12 @@ class QuerySet(object):
if count:
qs += ['SELECT COUNT(*)']
else:
fields = self.models._columns.keys()
fields = self.model._columns.keys()
if self._defer_fields:
fields = [f for f in fields if f not in self._defer_fields]
elif self._only_fields:
fields = [f for f in fields if f in self._only_fields]
db_fields = [self.model._columns[f].db_fields for f in fields]
db_fields = [self.model._columns[f].db_field_name for f in fields]
qs += ['SELECT {}'.format(', '.join(db_fields))]
qs += ['FROM {}'.format(self.column_family_name)]
@@ -182,7 +188,10 @@ class QuerySet(object):
if self._where:
qs += ['WHERE {}'.format(self._where_clause())]
#TODO: add support for limit, start, order by, and reverse
if not count:
#TODO: add support for limit, start, order by, and reverse
pass
return ' '.join(qs)
#----Reads------
@@ -191,6 +200,7 @@ class QuerySet(object):
conn = get_connection()
self._cursor = conn.cursor()
self._cursor.execute(self._select_query(), self._where_values())
self._rowcount = self._cursor.rowcount
return self
def _construct_instance(self, values):
@@ -267,7 +277,10 @@ class QuerySet(object):
con = get_connection()
cur = con.cursor()
cur.execute(self._select_query(count=True), self._where_values())
return cur.fetchone()
return cur.fetchone()[0]
def __len__(self):
return self.count()
def find(self, pk):
"""
@@ -319,23 +332,14 @@ class QuerySet(object):
prior to calling this.
"""
assert type(instance) == self.model
#organize data
value_pairs = []
#get pk
col = self.model._columns[self.model._pk_name]
values = instance.as_dict()
value_pairs += [(col.db_field, values.get(self.model._pk_name))]
#get defined fields and their column names
for name, col in self.model._columns.items():
if col.is_primary_key: continue
value_pairs += [(col.db_field, values.get(name))]
#add dynamic fields
for key, val in values.items():
if key in self.model._columns: continue
value_pairs += [(key, val)]
value_pairs += [(col.db_field_name, values.get(name))]
#construct query string
field_names = zip(*value_pairs)[0]
@@ -357,7 +361,21 @@ class QuerySet(object):
def delete(self, columns=[]):
"""
Deletes the contents of a query
:returns: number of rows deleted
"""
qs = ['DELETE FROM {}'.format(self.column_family_name)]
if self._where:
qs += ['WHERE {}'.format(self._where_clause())]
qs = ' '.join(qs)
#TODO: Return number of rows deleted
con = get_connection()
cur = con.cursor()
cur.execute(qs, self._where_values())
return cur.fetchone()
def delete_instance(self, instance):
""" Deletes one instance """
@@ -378,8 +396,8 @@ class QuerySet(object):
pkeys = []
qtypes = []
def add_column(col):
s = '{} {}'.format(col.db_field, col.db_type)
if col.primary_key: pkeys.append(col.db_field)
s = '{} {}'.format(col.db_field_name, col.db_type)
if col.primary_key: pkeys.append(col.db_field_name)
qtypes.append(s)
#add_column(self.model._columns[self.model._pk_name])
for name, col in self.model._columns.items():

View File

@@ -8,11 +8,11 @@ from cqlengine import query
class TestModel(Model):
test_id = columns.Integer(primary_key=True)
attempt_id = columns.Integer(primary_key=True)
descriptions = columns.Text()
description = columns.Text()
expected_result = columns.Integer()
test_result = columns.Integer(index=True)
class TestQuerySet(BaseCassEngTestCase):
class TestQuerySetOperation(BaseCassEngTestCase):
def test_query_filter_parsing(self):
"""
@@ -54,7 +54,7 @@ class TestQuerySet(BaseCassEngTestCase):
assert where == 'test_id = :{} AND expected_result >= :{}'.format(*ids)
def test_querystring_construction(self):
def test_querystring_generation(self):
"""
Tests the select querystring creation
"""
@@ -101,3 +101,60 @@ class TestQuerySet(BaseCassEngTestCase):
"""
Tests that setting only or defer fields that don't exist raises an exception
"""
class TestQuerySetUsage(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestQuerySetUsage, cls).setUpClass()
try: TestModel.objects._delete_column_family()
except: pass
TestModel.objects._create_column_family()
TestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30)
TestModel.objects.create(test_id=0, attempt_id=1, description='try2', expected_result=10, test_result=30)
TestModel.objects.create(test_id=0, attempt_id=2, description='try3', expected_result=15, test_result=30)
TestModel.objects.create(test_id=0, attempt_id=3, description='try4', expected_result=20, test_result=25)
TestModel.objects.create(test_id=1, attempt_id=0, description='try5', expected_result=5, test_result=25)
TestModel.objects.create(test_id=1, attempt_id=1, description='try6', expected_result=10, test_result=25)
TestModel.objects.create(test_id=1, attempt_id=2, description='try7', expected_result=15, test_result=25)
TestModel.objects.create(test_id=1, attempt_id=3, description='try8', expected_result=20, test_result=20)
TestModel.objects.create(test_id=2, attempt_id=0, description='try9', expected_result=50, test_result=40)
TestModel.objects.create(test_id=2, attempt_id=1, description='try10', expected_result=60, test_result=40)
TestModel.objects.create(test_id=2, attempt_id=2, description='try11', expected_result=70, test_result=45)
TestModel.objects.create(test_id=2, attempt_id=3, description='try12', expected_result=75, test_result=45)
@classmethod
def tearDownClass(cls):
super(TestQuerySetUsage, cls).tearDownClass()
TestModel.objects._delete_column_family()
def test_count(self):
q = TestModel.objects(test_id=0)
assert q.count() == 4
def test_iteration(self):
q = TestModel.objects(test_id=0)
#tuple of expected attempt_id, expected_result values
import ipdb; ipdb.set_trace()
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
q = TestModel.objects(attempt_id=3)
assert len(q) == 3
#tuple of expected test_id, expected_result values
compare_set = set([(0,20), (1,20), (2,75)])
for t in q:
val = t.test_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0