writing query delete method
added queryset usage tests refactoring how column names are stored and calculated
This commit is contained in:
@@ -56,6 +56,9 @@ class BaseColumn(object):
|
||||
self.default = default
|
||||
self.null = null
|
||||
|
||||
#the column name in the model definition
|
||||
self.column_name = None
|
||||
|
||||
self.value = None
|
||||
|
||||
#keep track of instantiation order
|
||||
@@ -112,17 +115,21 @@ class BaseColumn(object):
|
||||
"""
|
||||
Returns a column definition for CQL table definition
|
||||
"""
|
||||
dterms = [self.db_field, self.db_type]
|
||||
dterms = [self.db_field_name, self.db_type]
|
||||
#if self.primary_key:
|
||||
#dterms.append('PRIMARY KEY')
|
||||
return ' '.join(dterms)
|
||||
|
||||
def set_db_name(self, name):
|
||||
def set_column_name(self, name):
|
||||
"""
|
||||
Sets the column name during document class construction
|
||||
This value will be ignored if db_field is set in __init__
|
||||
"""
|
||||
self.db_field = self.db_field or name
|
||||
self.column_name = name
|
||||
|
||||
@property
|
||||
def db_field_name(self):
|
||||
return self.db_field or self.column_name
|
||||
|
||||
class Bytes(BaseColumn):
|
||||
db_type = 'blob'
|
||||
|
||||
@@ -62,13 +62,11 @@ class BaseModel(object):
|
||||
def save(self):
|
||||
is_new = self.pk is None
|
||||
self.validate()
|
||||
#self.objects._save_instance(self)
|
||||
self.objects.save(self)
|
||||
return self
|
||||
|
||||
def delete(self):
|
||||
""" Deletes this instance """
|
||||
#self.objects._delete_instance(self)
|
||||
self.objects.delete_instance(self)
|
||||
|
||||
|
||||
@@ -84,7 +82,7 @@ class ModelMetaClass(type):
|
||||
|
||||
def _transform_column(col_name, col_obj):
|
||||
_columns[col_name] = col_obj
|
||||
col_obj.set_db_name(col_name)
|
||||
col_obj.set_column_name(col_name)
|
||||
#set properties
|
||||
_get = lambda self: self._values[col_name].getval()
|
||||
_set = lambda self, val: self._values[col_name].setval(val)
|
||||
@@ -94,7 +92,6 @@ class ModelMetaClass(type):
|
||||
else:
|
||||
attrs[col_name] = property(_get, _set, _del)
|
||||
|
||||
|
||||
column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.BaseColumn)]
|
||||
column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position))
|
||||
|
||||
@@ -116,9 +113,9 @@ class ModelMetaClass(type):
|
||||
#check for duplicate column names
|
||||
col_names = set()
|
||||
for k,v in _columns.items():
|
||||
if v.db_field in col_names:
|
||||
raise ModelException("{} defines the column {} more than once".format(name, v.db_field))
|
||||
col_names.add(v.db_field)
|
||||
if v.db_field_name in col_names:
|
||||
raise ModelException("{} defines the column {} more than once".format(name, v.db_field_name))
|
||||
col_names.add(v.db_field_name)
|
||||
|
||||
#get column family name
|
||||
cf_name = attrs.pop('db_name', name)
|
||||
@@ -126,7 +123,7 @@ class ModelMetaClass(type):
|
||||
#create db_name -> model name map for loading
|
||||
db_map = {}
|
||||
for name, col in _columns.items():
|
||||
db_map[col.db_field] = name
|
||||
db_map[col.db_field_name] = name
|
||||
|
||||
#add management members to the class
|
||||
attrs['_columns'] = _columns
|
||||
|
||||
@@ -39,7 +39,7 @@ class QueryOperator(object):
|
||||
Returns this operator's portion of the WHERE clause
|
||||
:param valname: the dict key that this operator's compare value will be found in
|
||||
"""
|
||||
return '{} {} :{}'.format(self.column.db_field, self.cql_symbol, self.identifier)
|
||||
return '{} {} :{}'.format(self.column.db_field_name, self.cql_symbol, self.identifier)
|
||||
|
||||
def validate_operator(self):
|
||||
"""
|
||||
@@ -140,6 +140,12 @@ class QuerySet(object):
|
||||
|
||||
self._cursor = None
|
||||
|
||||
def __unicode__(self):
|
||||
return self._select_query()
|
||||
|
||||
def __str__(self):
|
||||
return str(self.__unicode__())
|
||||
|
||||
#----query generation / execution----
|
||||
|
||||
def _validate_where_syntax(self):
|
||||
@@ -169,12 +175,12 @@ class QuerySet(object):
|
||||
if count:
|
||||
qs += ['SELECT COUNT(*)']
|
||||
else:
|
||||
fields = self.models._columns.keys()
|
||||
fields = self.model._columns.keys()
|
||||
if self._defer_fields:
|
||||
fields = [f for f in fields if f not in self._defer_fields]
|
||||
elif self._only_fields:
|
||||
fields = [f for f in fields if f in self._only_fields]
|
||||
db_fields = [self.model._columns[f].db_fields for f in fields]
|
||||
db_fields = [self.model._columns[f].db_field_name for f in fields]
|
||||
qs += ['SELECT {}'.format(', '.join(db_fields))]
|
||||
|
||||
qs += ['FROM {}'.format(self.column_family_name)]
|
||||
@@ -182,7 +188,10 @@ class QuerySet(object):
|
||||
if self._where:
|
||||
qs += ['WHERE {}'.format(self._where_clause())]
|
||||
|
||||
#TODO: add support for limit, start, order by, and reverse
|
||||
if not count:
|
||||
#TODO: add support for limit, start, order by, and reverse
|
||||
pass
|
||||
|
||||
return ' '.join(qs)
|
||||
|
||||
#----Reads------
|
||||
@@ -191,6 +200,7 @@ class QuerySet(object):
|
||||
conn = get_connection()
|
||||
self._cursor = conn.cursor()
|
||||
self._cursor.execute(self._select_query(), self._where_values())
|
||||
self._rowcount = self._cursor.rowcount
|
||||
return self
|
||||
|
||||
def _construct_instance(self, values):
|
||||
@@ -267,7 +277,10 @@ class QuerySet(object):
|
||||
con = get_connection()
|
||||
cur = con.cursor()
|
||||
cur.execute(self._select_query(count=True), self._where_values())
|
||||
return cur.fetchone()
|
||||
return cur.fetchone()[0]
|
||||
|
||||
def __len__(self):
|
||||
return self.count()
|
||||
|
||||
def find(self, pk):
|
||||
"""
|
||||
@@ -319,23 +332,14 @@ class QuerySet(object):
|
||||
prior to calling this.
|
||||
"""
|
||||
assert type(instance) == self.model
|
||||
|
||||
#organize data
|
||||
value_pairs = []
|
||||
|
||||
#get pk
|
||||
col = self.model._columns[self.model._pk_name]
|
||||
values = instance.as_dict()
|
||||
value_pairs += [(col.db_field, values.get(self.model._pk_name))]
|
||||
|
||||
#get defined fields and their column names
|
||||
for name, col in self.model._columns.items():
|
||||
if col.is_primary_key: continue
|
||||
value_pairs += [(col.db_field, values.get(name))]
|
||||
|
||||
#add dynamic fields
|
||||
for key, val in values.items():
|
||||
if key in self.model._columns: continue
|
||||
value_pairs += [(key, val)]
|
||||
value_pairs += [(col.db_field_name, values.get(name))]
|
||||
|
||||
#construct query string
|
||||
field_names = zip(*value_pairs)[0]
|
||||
@@ -357,7 +361,21 @@ class QuerySet(object):
|
||||
def delete(self, columns=[]):
|
||||
"""
|
||||
Deletes the contents of a query
|
||||
|
||||
:returns: number of rows deleted
|
||||
"""
|
||||
qs = ['DELETE FROM {}'.format(self.column_family_name)]
|
||||
if self._where:
|
||||
qs += ['WHERE {}'.format(self._where_clause())]
|
||||
qs = ' '.join(qs)
|
||||
|
||||
#TODO: Return number of rows deleted
|
||||
con = get_connection()
|
||||
cur = con.cursor()
|
||||
cur.execute(qs, self._where_values())
|
||||
return cur.fetchone()
|
||||
|
||||
|
||||
|
||||
def delete_instance(self, instance):
|
||||
""" Deletes one instance """
|
||||
@@ -378,8 +396,8 @@ class QuerySet(object):
|
||||
pkeys = []
|
||||
qtypes = []
|
||||
def add_column(col):
|
||||
s = '{} {}'.format(col.db_field, col.db_type)
|
||||
if col.primary_key: pkeys.append(col.db_field)
|
||||
s = '{} {}'.format(col.db_field_name, col.db_type)
|
||||
if col.primary_key: pkeys.append(col.db_field_name)
|
||||
qtypes.append(s)
|
||||
#add_column(self.model._columns[self.model._pk_name])
|
||||
for name, col in self.model._columns.items():
|
||||
|
||||
@@ -8,11 +8,11 @@ from cqlengine import query
|
||||
class TestModel(Model):
|
||||
test_id = columns.Integer(primary_key=True)
|
||||
attempt_id = columns.Integer(primary_key=True)
|
||||
descriptions = columns.Text()
|
||||
description = columns.Text()
|
||||
expected_result = columns.Integer()
|
||||
test_result = columns.Integer(index=True)
|
||||
|
||||
class TestQuerySet(BaseCassEngTestCase):
|
||||
class TestQuerySetOperation(BaseCassEngTestCase):
|
||||
|
||||
def test_query_filter_parsing(self):
|
||||
"""
|
||||
@@ -54,7 +54,7 @@ class TestQuerySet(BaseCassEngTestCase):
|
||||
assert where == 'test_id = :{} AND expected_result >= :{}'.format(*ids)
|
||||
|
||||
|
||||
def test_querystring_construction(self):
|
||||
def test_querystring_generation(self):
|
||||
"""
|
||||
Tests the select querystring creation
|
||||
"""
|
||||
@@ -101,3 +101,60 @@ class TestQuerySet(BaseCassEngTestCase):
|
||||
"""
|
||||
Tests that setting only or defer fields that don't exist raises an exception
|
||||
"""
|
||||
|
||||
class TestQuerySetUsage(BaseCassEngTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestQuerySetUsage, cls).setUpClass()
|
||||
try: TestModel.objects._delete_column_family()
|
||||
except: pass
|
||||
TestModel.objects._create_column_family()
|
||||
|
||||
TestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30)
|
||||
TestModel.objects.create(test_id=0, attempt_id=1, description='try2', expected_result=10, test_result=30)
|
||||
TestModel.objects.create(test_id=0, attempt_id=2, description='try3', expected_result=15, test_result=30)
|
||||
TestModel.objects.create(test_id=0, attempt_id=3, description='try4', expected_result=20, test_result=25)
|
||||
|
||||
TestModel.objects.create(test_id=1, attempt_id=0, description='try5', expected_result=5, test_result=25)
|
||||
TestModel.objects.create(test_id=1, attempt_id=1, description='try6', expected_result=10, test_result=25)
|
||||
TestModel.objects.create(test_id=1, attempt_id=2, description='try7', expected_result=15, test_result=25)
|
||||
TestModel.objects.create(test_id=1, attempt_id=3, description='try8', expected_result=20, test_result=20)
|
||||
|
||||
TestModel.objects.create(test_id=2, attempt_id=0, description='try9', expected_result=50, test_result=40)
|
||||
TestModel.objects.create(test_id=2, attempt_id=1, description='try10', expected_result=60, test_result=40)
|
||||
TestModel.objects.create(test_id=2, attempt_id=2, description='try11', expected_result=70, test_result=45)
|
||||
TestModel.objects.create(test_id=2, attempt_id=3, description='try12', expected_result=75, test_result=45)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super(TestQuerySetUsage, cls).tearDownClass()
|
||||
TestModel.objects._delete_column_family()
|
||||
|
||||
def test_count(self):
|
||||
q = TestModel.objects(test_id=0)
|
||||
assert q.count() == 4
|
||||
|
||||
def test_iteration(self):
|
||||
q = TestModel.objects(test_id=0)
|
||||
#tuple of expected attempt_id, expected_result values
|
||||
import ipdb; ipdb.set_trace()
|
||||
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
|
||||
for t in q:
|
||||
val = t.attempt_id, t.expected_result
|
||||
assert val in compare_set
|
||||
compare_set.remove(val)
|
||||
assert len(compare_set) == 0
|
||||
|
||||
q = TestModel.objects(attempt_id=3)
|
||||
assert len(q) == 3
|
||||
#tuple of expected test_id, expected_result values
|
||||
compare_set = set([(0,20), (1,20), (2,75)])
|
||||
for t in q:
|
||||
val = t.test_id, t.expected_result
|
||||
assert val in compare_set
|
||||
compare_set.remove(val)
|
||||
assert len(compare_set) == 0
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user