From b7d91fe869fef8fcdebafc83b771db907df73f16 Mon Sep 17 00:00:00 2001 From: Blake Eggleston Date: Thu, 22 Nov 2012 23:16:49 -0800 Subject: [PATCH] writing query delete method added queryset usage tests refactoring how column names are stored and calculated --- cqlengine/columns.py | 13 ++++-- cqlengine/models.py | 13 ++---- cqlengine/query.py | 54 ++++++++++++++-------- cqlengine/tests/query/test_queryset.py | 63 ++++++++++++++++++++++++-- 4 files changed, 111 insertions(+), 32 deletions(-) diff --git a/cqlengine/columns.py b/cqlengine/columns.py index e81e8930..0eeb8051 100644 --- a/cqlengine/columns.py +++ b/cqlengine/columns.py @@ -56,6 +56,9 @@ class BaseColumn(object): self.default = default self.null = null + #the column name in the model definition + self.column_name = None + self.value = None #keep track of instantiation order @@ -112,17 +115,21 @@ class BaseColumn(object): """ Returns a column definition for CQL table definition """ - dterms = [self.db_field, self.db_type] + dterms = [self.db_field_name, self.db_type] #if self.primary_key: #dterms.append('PRIMARY KEY') return ' '.join(dterms) - def set_db_name(self, name): + def set_column_name(self, name): """ Sets the column name during document class construction This value will be ignored if db_field is set in __init__ """ - self.db_field = self.db_field or name + self.column_name = name + + @property + def db_field_name(self): + return self.db_field or self.column_name class Bytes(BaseColumn): db_type = 'blob' diff --git a/cqlengine/models.py b/cqlengine/models.py index ca51a21e..2d9de4ed 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -62,13 +62,11 @@ class BaseModel(object): def save(self): is_new = self.pk is None self.validate() - #self.objects._save_instance(self) self.objects.save(self) return self def delete(self): """ Deletes this instance """ - #self.objects._delete_instance(self) self.objects.delete_instance(self) @@ -84,7 +82,7 @@ class ModelMetaClass(type): def _transform_column(col_name, col_obj): _columns[col_name] = col_obj - col_obj.set_db_name(col_name) + col_obj.set_column_name(col_name) #set properties _get = lambda self: self._values[col_name].getval() _set = lambda self, val: self._values[col_name].setval(val) @@ -94,7 +92,6 @@ class ModelMetaClass(type): else: attrs[col_name] = property(_get, _set, _del) - column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.BaseColumn)] column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position)) @@ -116,9 +113,9 @@ class ModelMetaClass(type): #check for duplicate column names col_names = set() for k,v in _columns.items(): - if v.db_field in col_names: - raise ModelException("{} defines the column {} more than once".format(name, v.db_field)) - col_names.add(v.db_field) + if v.db_field_name in col_names: + raise ModelException("{} defines the column {} more than once".format(name, v.db_field_name)) + col_names.add(v.db_field_name) #get column family name cf_name = attrs.pop('db_name', name) @@ -126,7 +123,7 @@ class ModelMetaClass(type): #create db_name -> model name map for loading db_map = {} for name, col in _columns.items(): - db_map[col.db_field] = name + db_map[col.db_field_name] = name #add management members to the class attrs['_columns'] = _columns diff --git a/cqlengine/query.py b/cqlengine/query.py index 7b6c91b0..b2c8c9e8 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -39,7 +39,7 @@ class QueryOperator(object): Returns this operator's portion of the WHERE clause :param valname: the dict key that this operator's compare value will be found in """ - return '{} {} :{}'.format(self.column.db_field, self.cql_symbol, self.identifier) + return '{} {} :{}'.format(self.column.db_field_name, self.cql_symbol, self.identifier) def validate_operator(self): """ @@ -140,6 +140,12 @@ class QuerySet(object): self._cursor = None + def __unicode__(self): + return self._select_query() + + def __str__(self): + return str(self.__unicode__()) + #----query generation / execution---- def _validate_where_syntax(self): @@ -169,12 +175,12 @@ class QuerySet(object): if count: qs += ['SELECT COUNT(*)'] else: - fields = self.models._columns.keys() + fields = self.model._columns.keys() if self._defer_fields: fields = [f for f in fields if f not in self._defer_fields] elif self._only_fields: fields = [f for f in fields if f in self._only_fields] - db_fields = [self.model._columns[f].db_fields for f in fields] + db_fields = [self.model._columns[f].db_field_name for f in fields] qs += ['SELECT {}'.format(', '.join(db_fields))] qs += ['FROM {}'.format(self.column_family_name)] @@ -182,7 +188,10 @@ class QuerySet(object): if self._where: qs += ['WHERE {}'.format(self._where_clause())] - #TODO: add support for limit, start, order by, and reverse + if not count: + #TODO: add support for limit, start, order by, and reverse + pass + return ' '.join(qs) #----Reads------ @@ -191,6 +200,7 @@ class QuerySet(object): conn = get_connection() self._cursor = conn.cursor() self._cursor.execute(self._select_query(), self._where_values()) + self._rowcount = self._cursor.rowcount return self def _construct_instance(self, values): @@ -267,7 +277,10 @@ class QuerySet(object): con = get_connection() cur = con.cursor() cur.execute(self._select_query(count=True), self._where_values()) - return cur.fetchone() + return cur.fetchone()[0] + + def __len__(self): + return self.count() def find(self, pk): """ @@ -319,23 +332,14 @@ class QuerySet(object): prior to calling this. """ assert type(instance) == self.model + #organize data value_pairs = [] - - #get pk - col = self.model._columns[self.model._pk_name] values = instance.as_dict() - value_pairs += [(col.db_field, values.get(self.model._pk_name))] #get defined fields and their column names for name, col in self.model._columns.items(): - if col.is_primary_key: continue - value_pairs += [(col.db_field, values.get(name))] - - #add dynamic fields - for key, val in values.items(): - if key in self.model._columns: continue - value_pairs += [(key, val)] + value_pairs += [(col.db_field_name, values.get(name))] #construct query string field_names = zip(*value_pairs)[0] @@ -357,7 +361,21 @@ class QuerySet(object): def delete(self, columns=[]): """ Deletes the contents of a query + + :returns: number of rows deleted """ + qs = ['DELETE FROM {}'.format(self.column_family_name)] + if self._where: + qs += ['WHERE {}'.format(self._where_clause())] + qs = ' '.join(qs) + + #TODO: Return number of rows deleted + con = get_connection() + cur = con.cursor() + cur.execute(qs, self._where_values()) + return cur.fetchone() + + def delete_instance(self, instance): """ Deletes one instance """ @@ -378,8 +396,8 @@ class QuerySet(object): pkeys = [] qtypes = [] def add_column(col): - s = '{} {}'.format(col.db_field, col.db_type) - if col.primary_key: pkeys.append(col.db_field) + s = '{} {}'.format(col.db_field_name, col.db_type) + if col.primary_key: pkeys.append(col.db_field_name) qtypes.append(s) #add_column(self.model._columns[self.model._pk_name]) for name, col in self.model._columns.items(): diff --git a/cqlengine/tests/query/test_queryset.py b/cqlengine/tests/query/test_queryset.py index fe58dc72..9ac767b5 100644 --- a/cqlengine/tests/query/test_queryset.py +++ b/cqlengine/tests/query/test_queryset.py @@ -8,11 +8,11 @@ from cqlengine import query class TestModel(Model): test_id = columns.Integer(primary_key=True) attempt_id = columns.Integer(primary_key=True) - descriptions = columns.Text() + description = columns.Text() expected_result = columns.Integer() test_result = columns.Integer(index=True) -class TestQuerySet(BaseCassEngTestCase): +class TestQuerySetOperation(BaseCassEngTestCase): def test_query_filter_parsing(self): """ @@ -54,7 +54,7 @@ class TestQuerySet(BaseCassEngTestCase): assert where == 'test_id = :{} AND expected_result >= :{}'.format(*ids) - def test_querystring_construction(self): + def test_querystring_generation(self): """ Tests the select querystring creation """ @@ -101,3 +101,60 @@ class TestQuerySet(BaseCassEngTestCase): """ Tests that setting only or defer fields that don't exist raises an exception """ + +class TestQuerySetUsage(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestQuerySetUsage, cls).setUpClass() + try: TestModel.objects._delete_column_family() + except: pass + TestModel.objects._create_column_family() + + TestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30) + TestModel.objects.create(test_id=0, attempt_id=1, description='try2', expected_result=10, test_result=30) + TestModel.objects.create(test_id=0, attempt_id=2, description='try3', expected_result=15, test_result=30) + TestModel.objects.create(test_id=0, attempt_id=3, description='try4', expected_result=20, test_result=25) + + TestModel.objects.create(test_id=1, attempt_id=0, description='try5', expected_result=5, test_result=25) + TestModel.objects.create(test_id=1, attempt_id=1, description='try6', expected_result=10, test_result=25) + TestModel.objects.create(test_id=1, attempt_id=2, description='try7', expected_result=15, test_result=25) + TestModel.objects.create(test_id=1, attempt_id=3, description='try8', expected_result=20, test_result=20) + + TestModel.objects.create(test_id=2, attempt_id=0, description='try9', expected_result=50, test_result=40) + TestModel.objects.create(test_id=2, attempt_id=1, description='try10', expected_result=60, test_result=40) + TestModel.objects.create(test_id=2, attempt_id=2, description='try11', expected_result=70, test_result=45) + TestModel.objects.create(test_id=2, attempt_id=3, description='try12', expected_result=75, test_result=45) + + @classmethod + def tearDownClass(cls): + super(TestQuerySetUsage, cls).tearDownClass() + TestModel.objects._delete_column_family() + + def test_count(self): + q = TestModel.objects(test_id=0) + assert q.count() == 4 + + def test_iteration(self): + q = TestModel.objects(test_id=0) + #tuple of expected attempt_id, expected_result values + import ipdb; ipdb.set_trace() + compare_set = set([(0,5), (1,10), (2,15), (3,20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + q = TestModel.objects(attempt_id=3) + assert len(q) == 3 + #tuple of expected test_id, expected_result values + compare_set = set([(0,20), (1,20), (2,75)]) + for t in q: + val = t.test_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + +