writing query delete method
added queryset usage tests refactoring how column names are stored and calculated
This commit is contained in:
		@@ -56,6 +56,9 @@ class BaseColumn(object):
 | 
				
			|||||||
        self.default = default
 | 
					        self.default = default
 | 
				
			||||||
        self.null = null
 | 
					        self.null = null
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        #the column name in the model definition
 | 
				
			||||||
 | 
					        self.column_name = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.value = None
 | 
					        self.value = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        #keep track of instantiation order
 | 
					        #keep track of instantiation order
 | 
				
			||||||
@@ -112,17 +115,21 @@ class BaseColumn(object):
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        Returns a column definition for CQL table definition
 | 
					        Returns a column definition for CQL table definition
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        dterms = [self.db_field, self.db_type]
 | 
					        dterms = [self.db_field_name, self.db_type]
 | 
				
			||||||
        #if self.primary_key:
 | 
					        #if self.primary_key:
 | 
				
			||||||
            #dterms.append('PRIMARY KEY')
 | 
					            #dterms.append('PRIMARY KEY')
 | 
				
			||||||
        return ' '.join(dterms)
 | 
					        return ' '.join(dterms)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def set_db_name(self, name):
 | 
					    def set_column_name(self, name):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Sets the column name during document class construction
 | 
					        Sets the column name during document class construction
 | 
				
			||||||
        This value will be ignored if db_field is set in __init__
 | 
					        This value will be ignored if db_field is set in __init__
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self.db_field = self.db_field or name
 | 
					        self.column_name = name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def db_field_name(self):
 | 
				
			||||||
 | 
					        return self.db_field or self.column_name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Bytes(BaseColumn):
 | 
					class Bytes(BaseColumn):
 | 
				
			||||||
    db_type = 'blob'
 | 
					    db_type = 'blob'
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -62,13 +62,11 @@ class BaseModel(object):
 | 
				
			|||||||
    def save(self):
 | 
					    def save(self):
 | 
				
			||||||
        is_new = self.pk is None
 | 
					        is_new = self.pk is None
 | 
				
			||||||
        self.validate()
 | 
					        self.validate()
 | 
				
			||||||
        #self.objects._save_instance(self)
 | 
					 | 
				
			||||||
        self.objects.save(self)
 | 
					        self.objects.save(self)
 | 
				
			||||||
        return self
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def delete(self):
 | 
					    def delete(self):
 | 
				
			||||||
        """ Deletes this instance """
 | 
					        """ Deletes this instance """
 | 
				
			||||||
        #self.objects._delete_instance(self)
 | 
					 | 
				
			||||||
        self.objects.delete_instance(self)
 | 
					        self.objects.delete_instance(self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -84,7 +82,7 @@ class ModelMetaClass(type):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        def _transform_column(col_name, col_obj):
 | 
					        def _transform_column(col_name, col_obj):
 | 
				
			||||||
            _columns[col_name] = col_obj
 | 
					            _columns[col_name] = col_obj
 | 
				
			||||||
            col_obj.set_db_name(col_name)
 | 
					            col_obj.set_column_name(col_name)
 | 
				
			||||||
            #set properties
 | 
					            #set properties
 | 
				
			||||||
            _get = lambda self: self._values[col_name].getval()
 | 
					            _get = lambda self: self._values[col_name].getval()
 | 
				
			||||||
            _set = lambda self, val: self._values[col_name].setval(val)
 | 
					            _set = lambda self, val: self._values[col_name].setval(val)
 | 
				
			||||||
@@ -94,7 +92,6 @@ class ModelMetaClass(type):
 | 
				
			|||||||
            else:
 | 
					            else:
 | 
				
			||||||
                attrs[col_name] = property(_get, _set, _del)
 | 
					                attrs[col_name] = property(_get, _set, _del)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.BaseColumn)]
 | 
					        column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.BaseColumn)]
 | 
				
			||||||
        column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position))
 | 
					        column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -116,9 +113,9 @@ class ModelMetaClass(type):
 | 
				
			|||||||
        #check for duplicate column names
 | 
					        #check for duplicate column names
 | 
				
			||||||
        col_names = set()
 | 
					        col_names = set()
 | 
				
			||||||
        for k,v in _columns.items():
 | 
					        for k,v in _columns.items():
 | 
				
			||||||
            if v.db_field in col_names:
 | 
					            if v.db_field_name in col_names:
 | 
				
			||||||
                raise ModelException("{} defines the column {} more than once".format(name, v.db_field))
 | 
					                raise ModelException("{} defines the column {} more than once".format(name, v.db_field_name))
 | 
				
			||||||
            col_names.add(v.db_field)
 | 
					            col_names.add(v.db_field_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        #get column family name
 | 
					        #get column family name
 | 
				
			||||||
        cf_name = attrs.pop('db_name', name)
 | 
					        cf_name = attrs.pop('db_name', name)
 | 
				
			||||||
@@ -126,7 +123,7 @@ class ModelMetaClass(type):
 | 
				
			|||||||
        #create db_name -> model name map for loading
 | 
					        #create db_name -> model name map for loading
 | 
				
			||||||
        db_map = {}
 | 
					        db_map = {}
 | 
				
			||||||
        for name, col in _columns.items():
 | 
					        for name, col in _columns.items():
 | 
				
			||||||
            db_map[col.db_field] = name
 | 
					            db_map[col.db_field_name] = name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        #add management members to the class
 | 
					        #add management members to the class
 | 
				
			||||||
        attrs['_columns'] = _columns
 | 
					        attrs['_columns'] = _columns
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -39,7 +39,7 @@ class QueryOperator(object):
 | 
				
			|||||||
        Returns this operator's portion of the WHERE clause
 | 
					        Returns this operator's portion of the WHERE clause
 | 
				
			||||||
        :param valname: the dict key that this operator's compare value will be found in
 | 
					        :param valname: the dict key that this operator's compare value will be found in
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        return '{} {} :{}'.format(self.column.db_field, self.cql_symbol, self.identifier)
 | 
					        return '{} {} :{}'.format(self.column.db_field_name, self.cql_symbol, self.identifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate_operator(self):
 | 
					    def validate_operator(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -140,6 +140,12 @@ class QuerySet(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self._cursor = None
 | 
					        self._cursor = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __unicode__(self):
 | 
				
			||||||
 | 
					        return self._select_query()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __str__(self):
 | 
				
			||||||
 | 
					        return str(self.__unicode__())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #----query generation / execution----
 | 
					    #----query generation / execution----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _validate_where_syntax(self):
 | 
					    def _validate_where_syntax(self):
 | 
				
			||||||
@@ -169,12 +175,12 @@ class QuerySet(object):
 | 
				
			|||||||
        if count:
 | 
					        if count:
 | 
				
			||||||
            qs += ['SELECT COUNT(*)']
 | 
					            qs += ['SELECT COUNT(*)']
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            fields = self.models._columns.keys()
 | 
					            fields = self.model._columns.keys()
 | 
				
			||||||
            if self._defer_fields:
 | 
					            if self._defer_fields:
 | 
				
			||||||
                fields = [f for f in fields if f not in self._defer_fields]
 | 
					                fields = [f for f in fields if f not in self._defer_fields]
 | 
				
			||||||
            elif self._only_fields:
 | 
					            elif self._only_fields:
 | 
				
			||||||
                fields = [f for f in fields if f in self._only_fields]
 | 
					                fields = [f for f in fields if f in self._only_fields]
 | 
				
			||||||
            db_fields = [self.model._columns[f].db_fields for f in fields]
 | 
					            db_fields = [self.model._columns[f].db_field_name for f in fields]
 | 
				
			||||||
            qs += ['SELECT {}'.format(', '.join(db_fields))]
 | 
					            qs += ['SELECT {}'.format(', '.join(db_fields))]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        qs += ['FROM {}'.format(self.column_family_name)]
 | 
					        qs += ['FROM {}'.format(self.column_family_name)]
 | 
				
			||||||
@@ -182,7 +188,10 @@ class QuerySet(object):
 | 
				
			|||||||
        if self._where:
 | 
					        if self._where:
 | 
				
			||||||
            qs += ['WHERE {}'.format(self._where_clause())]
 | 
					            qs += ['WHERE {}'.format(self._where_clause())]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not count:
 | 
				
			||||||
            #TODO: add support for limit, start, order by, and reverse
 | 
					            #TODO: add support for limit, start, order by, and reverse
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return ' '.join(qs)
 | 
					        return ' '.join(qs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #----Reads------
 | 
					    #----Reads------
 | 
				
			||||||
@@ -191,6 +200,7 @@ class QuerySet(object):
 | 
				
			|||||||
            conn = get_connection()
 | 
					            conn = get_connection()
 | 
				
			||||||
            self._cursor = conn.cursor()
 | 
					            self._cursor = conn.cursor()
 | 
				
			||||||
            self._cursor.execute(self._select_query(), self._where_values())
 | 
					            self._cursor.execute(self._select_query(), self._where_values())
 | 
				
			||||||
 | 
					            self._rowcount = self._cursor.rowcount
 | 
				
			||||||
        return self
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _construct_instance(self, values):
 | 
					    def _construct_instance(self, values):
 | 
				
			||||||
@@ -267,7 +277,10 @@ class QuerySet(object):
 | 
				
			|||||||
        con = get_connection()
 | 
					        con = get_connection()
 | 
				
			||||||
        cur = con.cursor()
 | 
					        cur = con.cursor()
 | 
				
			||||||
        cur.execute(self._select_query(count=True), self._where_values())
 | 
					        cur.execute(self._select_query(count=True), self._where_values())
 | 
				
			||||||
        return cur.fetchone()
 | 
					        return cur.fetchone()[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __len__(self):
 | 
				
			||||||
 | 
					        return self.count()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def find(self, pk):
 | 
					    def find(self, pk):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -319,23 +332,14 @@ class QuerySet(object):
 | 
				
			|||||||
        prior to calling this.
 | 
					        prior to calling this.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        assert type(instance) == self.model
 | 
					        assert type(instance) == self.model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        #organize data
 | 
					        #organize data
 | 
				
			||||||
        value_pairs = []
 | 
					        value_pairs = []
 | 
				
			||||||
 | 
					 | 
				
			||||||
        #get pk
 | 
					 | 
				
			||||||
        col = self.model._columns[self.model._pk_name]
 | 
					 | 
				
			||||||
        values = instance.as_dict()
 | 
					        values = instance.as_dict()
 | 
				
			||||||
        value_pairs += [(col.db_field, values.get(self.model._pk_name))]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        #get defined fields and their column names
 | 
					        #get defined fields and their column names
 | 
				
			||||||
        for name, col in self.model._columns.items():
 | 
					        for name, col in self.model._columns.items():
 | 
				
			||||||
            if col.is_primary_key: continue
 | 
					            value_pairs += [(col.db_field_name, values.get(name))]
 | 
				
			||||||
            value_pairs += [(col.db_field, values.get(name))]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        #add dynamic fields
 | 
					 | 
				
			||||||
        for key, val in values.items():
 | 
					 | 
				
			||||||
            if key in self.model._columns: continue
 | 
					 | 
				
			||||||
            value_pairs += [(key, val)]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        #construct query string
 | 
					        #construct query string
 | 
				
			||||||
        field_names = zip(*value_pairs)[0]
 | 
					        field_names = zip(*value_pairs)[0]
 | 
				
			||||||
@@ -357,7 +361,21 @@ class QuerySet(object):
 | 
				
			|||||||
    def delete(self, columns=[]):
 | 
					    def delete(self, columns=[]):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Deletes the contents of a query
 | 
					        Deletes the contents of a query
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :returns: number of rows deleted
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					        qs = ['DELETE FROM {}'.format(self.column_family_name)]
 | 
				
			||||||
 | 
					        if self._where:
 | 
				
			||||||
 | 
					            qs += ['WHERE {}'.format(self._where_clause())]
 | 
				
			||||||
 | 
					        qs = ' '.join(qs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        #TODO: Return number of rows deleted
 | 
				
			||||||
 | 
					        con = get_connection()
 | 
				
			||||||
 | 
					        cur = con.cursor()
 | 
				
			||||||
 | 
					        cur.execute(qs, self._where_values())
 | 
				
			||||||
 | 
					        return cur.fetchone()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def delete_instance(self, instance):
 | 
					    def delete_instance(self, instance):
 | 
				
			||||||
        """ Deletes one instance """
 | 
					        """ Deletes one instance """
 | 
				
			||||||
@@ -378,8 +396,8 @@ class QuerySet(object):
 | 
				
			|||||||
        pkeys = []
 | 
					        pkeys = []
 | 
				
			||||||
        qtypes = []
 | 
					        qtypes = []
 | 
				
			||||||
        def add_column(col):
 | 
					        def add_column(col):
 | 
				
			||||||
            s = '{} {}'.format(col.db_field, col.db_type)
 | 
					            s = '{} {}'.format(col.db_field_name, col.db_type)
 | 
				
			||||||
            if col.primary_key: pkeys.append(col.db_field)
 | 
					            if col.primary_key: pkeys.append(col.db_field_name)
 | 
				
			||||||
            qtypes.append(s)
 | 
					            qtypes.append(s)
 | 
				
			||||||
        #add_column(self.model._columns[self.model._pk_name])
 | 
					        #add_column(self.model._columns[self.model._pk_name])
 | 
				
			||||||
        for name, col in self.model._columns.items():
 | 
					        for name, col in self.model._columns.items():
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,11 +8,11 @@ from cqlengine import query
 | 
				
			|||||||
class TestModel(Model):
 | 
					class TestModel(Model):
 | 
				
			||||||
    test_id = columns.Integer(primary_key=True)
 | 
					    test_id = columns.Integer(primary_key=True)
 | 
				
			||||||
    attempt_id = columns.Integer(primary_key=True)
 | 
					    attempt_id = columns.Integer(primary_key=True)
 | 
				
			||||||
    descriptions = columns.Text()
 | 
					    description = columns.Text()
 | 
				
			||||||
    expected_result = columns.Integer()
 | 
					    expected_result = columns.Integer()
 | 
				
			||||||
    test_result = columns.Integer(index=True)
 | 
					    test_result = columns.Integer(index=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestQuerySet(BaseCassEngTestCase):
 | 
					class TestQuerySetOperation(BaseCassEngTestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_query_filter_parsing(self):
 | 
					    def test_query_filter_parsing(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -54,7 +54,7 @@ class TestQuerySet(BaseCassEngTestCase):
 | 
				
			|||||||
        assert where == 'test_id = :{} AND expected_result >= :{}'.format(*ids)
 | 
					        assert where == 'test_id = :{} AND expected_result >= :{}'.format(*ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_querystring_construction(self):
 | 
					    def test_querystring_generation(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Tests the select querystring creation
 | 
					        Tests the select querystring creation
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -101,3 +101,60 @@ class TestQuerySet(BaseCassEngTestCase):
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        Tests that setting only or defer fields that don't exist raises an exception
 | 
					        Tests that setting only or defer fields that don't exist raises an exception
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestQuerySetUsage(BaseCassEngTestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def setUpClass(cls):
 | 
				
			||||||
 | 
					        super(TestQuerySetUsage, cls).setUpClass()
 | 
				
			||||||
 | 
					        try: TestModel.objects._delete_column_family()
 | 
				
			||||||
 | 
					        except: pass
 | 
				
			||||||
 | 
					        TestModel.objects._create_column_family()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30)
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=0, attempt_id=1, description='try2', expected_result=10, test_result=30)
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=0, attempt_id=2, description='try3', expected_result=15, test_result=30)
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=0, attempt_id=3, description='try4', expected_result=20, test_result=25)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=1, attempt_id=0, description='try5', expected_result=5, test_result=25)
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=1, attempt_id=1, description='try6', expected_result=10, test_result=25)
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=1, attempt_id=2, description='try7', expected_result=15, test_result=25)
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=1, attempt_id=3, description='try8', expected_result=20, test_result=20)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=2, attempt_id=0, description='try9', expected_result=50, test_result=40)
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=2, attempt_id=1, description='try10', expected_result=60, test_result=40)
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=2, attempt_id=2, description='try11', expected_result=70, test_result=45)
 | 
				
			||||||
 | 
					        TestModel.objects.create(test_id=2, attempt_id=3, description='try12', expected_result=75, test_result=45)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def tearDownClass(cls):
 | 
				
			||||||
 | 
					        super(TestQuerySetUsage, cls).tearDownClass()
 | 
				
			||||||
 | 
					        TestModel.objects._delete_column_family()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_count(self):
 | 
				
			||||||
 | 
					        q = TestModel.objects(test_id=0)
 | 
				
			||||||
 | 
					        assert q.count() == 4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_iteration(self):
 | 
				
			||||||
 | 
					        q = TestModel.objects(test_id=0)
 | 
				
			||||||
 | 
					        #tuple of expected attempt_id, expected_result values
 | 
				
			||||||
 | 
					        import ipdb; ipdb.set_trace()
 | 
				
			||||||
 | 
					        compare_set = set([(0,5), (1,10), (2,15), (3,20)])
 | 
				
			||||||
 | 
					        for t in q:
 | 
				
			||||||
 | 
					            val = t.attempt_id, t.expected_result
 | 
				
			||||||
 | 
					            assert val in compare_set
 | 
				
			||||||
 | 
					            compare_set.remove(val)
 | 
				
			||||||
 | 
					        assert len(compare_set) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        q = TestModel.objects(attempt_id=3)
 | 
				
			||||||
 | 
					        assert len(q) == 3
 | 
				
			||||||
 | 
					        #tuple of expected test_id, expected_result values
 | 
				
			||||||
 | 
					        compare_set = set([(0,20), (1,20), (2,75)])
 | 
				
			||||||
 | 
					        for t in q:
 | 
				
			||||||
 | 
					            val = t.test_id, t.expected_result
 | 
				
			||||||
 | 
					            assert val in compare_set
 | 
				
			||||||
 | 
					            compare_set.remove(val)
 | 
				
			||||||
 | 
					        assert len(compare_set) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user