writing query delete method
added queryset usage tests refactoring how column names are stored and calculated
This commit is contained in:
@@ -56,6 +56,9 @@ class BaseColumn(object):
|
|||||||
self.default = default
|
self.default = default
|
||||||
self.null = null
|
self.null = null
|
||||||
|
|
||||||
|
#the column name in the model definition
|
||||||
|
self.column_name = None
|
||||||
|
|
||||||
self.value = None
|
self.value = None
|
||||||
|
|
||||||
#keep track of instantiation order
|
#keep track of instantiation order
|
||||||
@@ -112,17 +115,21 @@ class BaseColumn(object):
|
|||||||
"""
|
"""
|
||||||
Returns a column definition for CQL table definition
|
Returns a column definition for CQL table definition
|
||||||
"""
|
"""
|
||||||
dterms = [self.db_field, self.db_type]
|
dterms = [self.db_field_name, self.db_type]
|
||||||
#if self.primary_key:
|
#if self.primary_key:
|
||||||
#dterms.append('PRIMARY KEY')
|
#dterms.append('PRIMARY KEY')
|
||||||
return ' '.join(dterms)
|
return ' '.join(dterms)
|
||||||
|
|
||||||
def set_db_name(self, name):
|
def set_column_name(self, name):
|
||||||
"""
|
"""
|
||||||
Sets the column name during document class construction
|
Sets the column name during document class construction
|
||||||
This value will be ignored if db_field is set in __init__
|
This value will be ignored if db_field is set in __init__
|
||||||
"""
|
"""
|
||||||
self.db_field = self.db_field or name
|
self.column_name = name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db_field_name(self):
|
||||||
|
return self.db_field or self.column_name
|
||||||
|
|
||||||
class Bytes(BaseColumn):
|
class Bytes(BaseColumn):
|
||||||
db_type = 'blob'
|
db_type = 'blob'
|
||||||
|
|||||||
@@ -62,13 +62,11 @@ class BaseModel(object):
|
|||||||
def save(self):
|
def save(self):
|
||||||
is_new = self.pk is None
|
is_new = self.pk is None
|
||||||
self.validate()
|
self.validate()
|
||||||
#self.objects._save_instance(self)
|
|
||||||
self.objects.save(self)
|
self.objects.save(self)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
""" Deletes this instance """
|
""" Deletes this instance """
|
||||||
#self.objects._delete_instance(self)
|
|
||||||
self.objects.delete_instance(self)
|
self.objects.delete_instance(self)
|
||||||
|
|
||||||
|
|
||||||
@@ -84,7 +82,7 @@ class ModelMetaClass(type):
|
|||||||
|
|
||||||
def _transform_column(col_name, col_obj):
|
def _transform_column(col_name, col_obj):
|
||||||
_columns[col_name] = col_obj
|
_columns[col_name] = col_obj
|
||||||
col_obj.set_db_name(col_name)
|
col_obj.set_column_name(col_name)
|
||||||
#set properties
|
#set properties
|
||||||
_get = lambda self: self._values[col_name].getval()
|
_get = lambda self: self._values[col_name].getval()
|
||||||
_set = lambda self, val: self._values[col_name].setval(val)
|
_set = lambda self, val: self._values[col_name].setval(val)
|
||||||
@@ -94,7 +92,6 @@ class ModelMetaClass(type):
|
|||||||
else:
|
else:
|
||||||
attrs[col_name] = property(_get, _set, _del)
|
attrs[col_name] = property(_get, _set, _del)
|
||||||
|
|
||||||
|
|
||||||
column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.BaseColumn)]
|
column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.BaseColumn)]
|
||||||
column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position))
|
column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position))
|
||||||
|
|
||||||
@@ -116,9 +113,9 @@ class ModelMetaClass(type):
|
|||||||
#check for duplicate column names
|
#check for duplicate column names
|
||||||
col_names = set()
|
col_names = set()
|
||||||
for k,v in _columns.items():
|
for k,v in _columns.items():
|
||||||
if v.db_field in col_names:
|
if v.db_field_name in col_names:
|
||||||
raise ModelException("{} defines the column {} more than once".format(name, v.db_field))
|
raise ModelException("{} defines the column {} more than once".format(name, v.db_field_name))
|
||||||
col_names.add(v.db_field)
|
col_names.add(v.db_field_name)
|
||||||
|
|
||||||
#get column family name
|
#get column family name
|
||||||
cf_name = attrs.pop('db_name', name)
|
cf_name = attrs.pop('db_name', name)
|
||||||
@@ -126,7 +123,7 @@ class ModelMetaClass(type):
|
|||||||
#create db_name -> model name map for loading
|
#create db_name -> model name map for loading
|
||||||
db_map = {}
|
db_map = {}
|
||||||
for name, col in _columns.items():
|
for name, col in _columns.items():
|
||||||
db_map[col.db_field] = name
|
db_map[col.db_field_name] = name
|
||||||
|
|
||||||
#add management members to the class
|
#add management members to the class
|
||||||
attrs['_columns'] = _columns
|
attrs['_columns'] = _columns
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class QueryOperator(object):
|
|||||||
Returns this operator's portion of the WHERE clause
|
Returns this operator's portion of the WHERE clause
|
||||||
:param valname: the dict key that this operator's compare value will be found in
|
:param valname: the dict key that this operator's compare value will be found in
|
||||||
"""
|
"""
|
||||||
return '{} {} :{}'.format(self.column.db_field, self.cql_symbol, self.identifier)
|
return '{} {} :{}'.format(self.column.db_field_name, self.cql_symbol, self.identifier)
|
||||||
|
|
||||||
def validate_operator(self):
|
def validate_operator(self):
|
||||||
"""
|
"""
|
||||||
@@ -140,6 +140,12 @@ class QuerySet(object):
|
|||||||
|
|
||||||
self._cursor = None
|
self._cursor = None
|
||||||
|
|
||||||
|
def __unicode__(self):
|
||||||
|
return self._select_query()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(self.__unicode__())
|
||||||
|
|
||||||
#----query generation / execution----
|
#----query generation / execution----
|
||||||
|
|
||||||
def _validate_where_syntax(self):
|
def _validate_where_syntax(self):
|
||||||
@@ -169,12 +175,12 @@ class QuerySet(object):
|
|||||||
if count:
|
if count:
|
||||||
qs += ['SELECT COUNT(*)']
|
qs += ['SELECT COUNT(*)']
|
||||||
else:
|
else:
|
||||||
fields = self.models._columns.keys()
|
fields = self.model._columns.keys()
|
||||||
if self._defer_fields:
|
if self._defer_fields:
|
||||||
fields = [f for f in fields if f not in self._defer_fields]
|
fields = [f for f in fields if f not in self._defer_fields]
|
||||||
elif self._only_fields:
|
elif self._only_fields:
|
||||||
fields = [f for f in fields if f in self._only_fields]
|
fields = [f for f in fields if f in self._only_fields]
|
||||||
db_fields = [self.model._columns[f].db_fields for f in fields]
|
db_fields = [self.model._columns[f].db_field_name for f in fields]
|
||||||
qs += ['SELECT {}'.format(', '.join(db_fields))]
|
qs += ['SELECT {}'.format(', '.join(db_fields))]
|
||||||
|
|
||||||
qs += ['FROM {}'.format(self.column_family_name)]
|
qs += ['FROM {}'.format(self.column_family_name)]
|
||||||
@@ -182,7 +188,10 @@ class QuerySet(object):
|
|||||||
if self._where:
|
if self._where:
|
||||||
qs += ['WHERE {}'.format(self._where_clause())]
|
qs += ['WHERE {}'.format(self._where_clause())]
|
||||||
|
|
||||||
#TODO: add support for limit, start, order by, and reverse
|
if not count:
|
||||||
|
#TODO: add support for limit, start, order by, and reverse
|
||||||
|
pass
|
||||||
|
|
||||||
return ' '.join(qs)
|
return ' '.join(qs)
|
||||||
|
|
||||||
#----Reads------
|
#----Reads------
|
||||||
@@ -191,6 +200,7 @@ class QuerySet(object):
|
|||||||
conn = get_connection()
|
conn = get_connection()
|
||||||
self._cursor = conn.cursor()
|
self._cursor = conn.cursor()
|
||||||
self._cursor.execute(self._select_query(), self._where_values())
|
self._cursor.execute(self._select_query(), self._where_values())
|
||||||
|
self._rowcount = self._cursor.rowcount
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _construct_instance(self, values):
|
def _construct_instance(self, values):
|
||||||
@@ -267,7 +277,10 @@ class QuerySet(object):
|
|||||||
con = get_connection()
|
con = get_connection()
|
||||||
cur = con.cursor()
|
cur = con.cursor()
|
||||||
cur.execute(self._select_query(count=True), self._where_values())
|
cur.execute(self._select_query(count=True), self._where_values())
|
||||||
return cur.fetchone()
|
return cur.fetchone()[0]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.count()
|
||||||
|
|
||||||
def find(self, pk):
|
def find(self, pk):
|
||||||
"""
|
"""
|
||||||
@@ -319,23 +332,14 @@ class QuerySet(object):
|
|||||||
prior to calling this.
|
prior to calling this.
|
||||||
"""
|
"""
|
||||||
assert type(instance) == self.model
|
assert type(instance) == self.model
|
||||||
|
|
||||||
#organize data
|
#organize data
|
||||||
value_pairs = []
|
value_pairs = []
|
||||||
|
|
||||||
#get pk
|
|
||||||
col = self.model._columns[self.model._pk_name]
|
|
||||||
values = instance.as_dict()
|
values = instance.as_dict()
|
||||||
value_pairs += [(col.db_field, values.get(self.model._pk_name))]
|
|
||||||
|
|
||||||
#get defined fields and their column names
|
#get defined fields and their column names
|
||||||
for name, col in self.model._columns.items():
|
for name, col in self.model._columns.items():
|
||||||
if col.is_primary_key: continue
|
value_pairs += [(col.db_field_name, values.get(name))]
|
||||||
value_pairs += [(col.db_field, values.get(name))]
|
|
||||||
|
|
||||||
#add dynamic fields
|
|
||||||
for key, val in values.items():
|
|
||||||
if key in self.model._columns: continue
|
|
||||||
value_pairs += [(key, val)]
|
|
||||||
|
|
||||||
#construct query string
|
#construct query string
|
||||||
field_names = zip(*value_pairs)[0]
|
field_names = zip(*value_pairs)[0]
|
||||||
@@ -357,7 +361,21 @@ class QuerySet(object):
|
|||||||
def delete(self, columns=[]):
|
def delete(self, columns=[]):
|
||||||
"""
|
"""
|
||||||
Deletes the contents of a query
|
Deletes the contents of a query
|
||||||
|
|
||||||
|
:returns: number of rows deleted
|
||||||
"""
|
"""
|
||||||
|
qs = ['DELETE FROM {}'.format(self.column_family_name)]
|
||||||
|
if self._where:
|
||||||
|
qs += ['WHERE {}'.format(self._where_clause())]
|
||||||
|
qs = ' '.join(qs)
|
||||||
|
|
||||||
|
#TODO: Return number of rows deleted
|
||||||
|
con = get_connection()
|
||||||
|
cur = con.cursor()
|
||||||
|
cur.execute(qs, self._where_values())
|
||||||
|
return cur.fetchone()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def delete_instance(self, instance):
|
def delete_instance(self, instance):
|
||||||
""" Deletes one instance """
|
""" Deletes one instance """
|
||||||
@@ -378,8 +396,8 @@ class QuerySet(object):
|
|||||||
pkeys = []
|
pkeys = []
|
||||||
qtypes = []
|
qtypes = []
|
||||||
def add_column(col):
|
def add_column(col):
|
||||||
s = '{} {}'.format(col.db_field, col.db_type)
|
s = '{} {}'.format(col.db_field_name, col.db_type)
|
||||||
if col.primary_key: pkeys.append(col.db_field)
|
if col.primary_key: pkeys.append(col.db_field_name)
|
||||||
qtypes.append(s)
|
qtypes.append(s)
|
||||||
#add_column(self.model._columns[self.model._pk_name])
|
#add_column(self.model._columns[self.model._pk_name])
|
||||||
for name, col in self.model._columns.items():
|
for name, col in self.model._columns.items():
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ from cqlengine import query
|
|||||||
class TestModel(Model):
|
class TestModel(Model):
|
||||||
test_id = columns.Integer(primary_key=True)
|
test_id = columns.Integer(primary_key=True)
|
||||||
attempt_id = columns.Integer(primary_key=True)
|
attempt_id = columns.Integer(primary_key=True)
|
||||||
descriptions = columns.Text()
|
description = columns.Text()
|
||||||
expected_result = columns.Integer()
|
expected_result = columns.Integer()
|
||||||
test_result = columns.Integer(index=True)
|
test_result = columns.Integer(index=True)
|
||||||
|
|
||||||
class TestQuerySet(BaseCassEngTestCase):
|
class TestQuerySetOperation(BaseCassEngTestCase):
|
||||||
|
|
||||||
def test_query_filter_parsing(self):
|
def test_query_filter_parsing(self):
|
||||||
"""
|
"""
|
||||||
@@ -54,7 +54,7 @@ class TestQuerySet(BaseCassEngTestCase):
|
|||||||
assert where == 'test_id = :{} AND expected_result >= :{}'.format(*ids)
|
assert where == 'test_id = :{} AND expected_result >= :{}'.format(*ids)
|
||||||
|
|
||||||
|
|
||||||
def test_querystring_construction(self):
|
def test_querystring_generation(self):
|
||||||
"""
|
"""
|
||||||
Tests the select querystring creation
|
Tests the select querystring creation
|
||||||
"""
|
"""
|
||||||
@@ -101,3 +101,60 @@ class TestQuerySet(BaseCassEngTestCase):
|
|||||||
"""
|
"""
|
||||||
Tests that setting only or defer fields that don't exist raises an exception
|
Tests that setting only or defer fields that don't exist raises an exception
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class TestQuerySetUsage(BaseCassEngTestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
super(TestQuerySetUsage, cls).setUpClass()
|
||||||
|
try: TestModel.objects._delete_column_family()
|
||||||
|
except: pass
|
||||||
|
TestModel.objects._create_column_family()
|
||||||
|
|
||||||
|
TestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30)
|
||||||
|
TestModel.objects.create(test_id=0, attempt_id=1, description='try2', expected_result=10, test_result=30)
|
||||||
|
TestModel.objects.create(test_id=0, attempt_id=2, description='try3', expected_result=15, test_result=30)
|
||||||
|
TestModel.objects.create(test_id=0, attempt_id=3, description='try4', expected_result=20, test_result=25)
|
||||||
|
|
||||||
|
TestModel.objects.create(test_id=1, attempt_id=0, description='try5', expected_result=5, test_result=25)
|
||||||
|
TestModel.objects.create(test_id=1, attempt_id=1, description='try6', expected_result=10, test_result=25)
|
||||||
|
TestModel.objects.create(test_id=1, attempt_id=2, description='try7', expected_result=15, test_result=25)
|
||||||
|
TestModel.objects.create(test_id=1, attempt_id=3, description='try8', expected_result=20, test_result=20)
|
||||||
|
|
||||||
|
TestModel.objects.create(test_id=2, attempt_id=0, description='try9', expected_result=50, test_result=40)
|
||||||
|
TestModel.objects.create(test_id=2, attempt_id=1, description='try10', expected_result=60, test_result=40)
|
||||||
|
TestModel.objects.create(test_id=2, attempt_id=2, description='try11', expected_result=70, test_result=45)
|
||||||
|
TestModel.objects.create(test_id=2, attempt_id=3, description='try12', expected_result=75, test_result=45)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
super(TestQuerySetUsage, cls).tearDownClass()
|
||||||
|
TestModel.objects._delete_column_family()
|
||||||
|
|
||||||
|
def test_count(self):
|
||||||
|
q = TestModel.objects(test_id=0)
|
||||||
|
assert q.count() == 4
|
||||||
|
|
||||||
|
def test_iteration(self):
|
||||||
|
q = TestModel.objects(test_id=0)
|
||||||
|
#tuple of expected attempt_id, expected_result values
|
||||||
|
import ipdb; ipdb.set_trace()
|
||||||
|
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
|
||||||
|
for t in q:
|
||||||
|
val = t.attempt_id, t.expected_result
|
||||||
|
assert val in compare_set
|
||||||
|
compare_set.remove(val)
|
||||||
|
assert len(compare_set) == 0
|
||||||
|
|
||||||
|
q = TestModel.objects(attempt_id=3)
|
||||||
|
assert len(q) == 3
|
||||||
|
#tuple of expected test_id, expected_result values
|
||||||
|
compare_set = set([(0,20), (1,20), (2,75)])
|
||||||
|
for t in q:
|
||||||
|
val = t.test_id, t.expected_result
|
||||||
|
assert val in compare_set
|
||||||
|
compare_set.remove(val)
|
||||||
|
assert len(compare_set) == 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user