diff --git a/cqlengine/columns.py b/cqlengine/columns.py index 3c9bad45..ebf8cbc2 100644 --- a/cqlengine/columns.py +++ b/cqlengine/columns.py @@ -60,7 +60,7 @@ class Column(object): instance_counter = 0 - def __init__(self, primary_key=False, index=False, db_field=None, default=None, required=True): + def __init__(self, primary_key=False, partition_key=False, index=False, db_field=None, default=None, required=True, clustering_order=None): """ :param primary_key: bool flag, indicates this column is a primary key. The first primary key defined on a model is the partition key, all others are cluster keys @@ -69,15 +69,13 @@ class Column(object): :param default: the default value, can be a value or a callable (no args) :param required: boolean, is the field required? """ - self.primary_key = primary_key + self.partition_key = partition_key + self.primary_key = partition_key or primary_key self.index = index self.db_field = db_field self.default = default self.required = required - - #only the model meta class should touch this - self._partition_key = False - + self.clustering_order = clustering_order #the column name in the model definition self.column_name = None @@ -137,7 +135,7 @@ class Column(object): """ Returns a column definition for CQL table definition """ - return '"{}" {}'.format(self.db_field_name, self.db_type) + return '{} {}'.format(self.cql, self.db_type) def set_column_name(self, name): """ @@ -156,6 +154,13 @@ class Column(object): """ Returns the name of the cql index """ return 'index_{}'.format(self.db_field_name) + @property + def cql(self): + return self.get_cql() + + def get_cql(self): + return '"{}"'.format(self.db_field_name) + class Bytes(Column): db_type = 'blob' @@ -645,4 +650,26 @@ class Map(BaseContainerColumn): return del_statements +class _PartitionKeys(Column): + class value_manager(BaseValueManager): + pass + + def __init__(self, model): + self.model = model + +class _PartitionKeysToken(Column): + """ + virtual column representing 
token of partition columns. + Used by filter(pk__token=Token(...)) filters + """ + + def __init__(self, model): + self.partition_columns = model._partition_keys.values() + super(_PartitionKeysToken, self).__init__(partition_key=True) + + def to_database(self, value): + raise NotImplementedError + + def get_cql(self): + return "token({})".format(", ".join(c.cql for c in self.partition_columns)) diff --git a/cqlengine/functions.py b/cqlengine/functions.py index 3f023e75..94784180 100644 --- a/cqlengine/functions.py +++ b/cqlengine/functions.py @@ -1,31 +1,40 @@ from datetime import datetime +from uuid import uuid1 from cqlengine.exceptions import ValidationError -class BaseQueryFunction(object): +class QueryValue(object): + """ + Base class for query filter values. Subclasses of these classes can + be passed into .filter() keyword args + """ + + _cql_string = ':{}' + + def __init__(self, value, identifier=None): + self.value = value + self.identifier = uuid1().hex if identifier is None else identifier + + def get_cql(self): + return self._cql_string.format(self.identifier) + + def get_value(self): + return self.value + + def get_dict(self, column): + return {self.identifier: column.to_database(self.get_value())} + + @property + def cql(self): + return self.get_cql() + +class BaseQueryFunction(QueryValue): """ Base class for filtering functions. 
Subclasses of these classes can be passed into .filter() and will be translated into CQL functions in the resulting query """ - _cql_string = None - - def __init__(self, value): - self.value = value - - def to_cql(self, value_id): - """ - Returns a function for cql with the value id as it's argument - """ - return self._cql_string.format(value_id) - - def get_value(self): - raise NotImplementedError - - def format_cql(self, field, operator, value_id): - return '"{}" {} {}'.format(field, operator, self.to_cql(value_id)) - class MinTimeUUID(BaseQueryFunction): _cql_string = 'MinTimeUUID(:{})' @@ -61,10 +70,19 @@ class MaxTimeUUID(BaseQueryFunction): return long((self.value - epoch).total_seconds() * 1000) class Token(BaseQueryFunction): - _cql_string = 'token(:{})' - def format_cql(self, field, operator, value_id): - return 'token("{}") {} {}'.format(field, operator, self.to_cql(value_id)) + def __init__(self, *values): + if len(values) == 1 and isinstance(values[0], (list, tuple)): + values = values[0] + super(Token, self).__init__(values, [uuid1().hex for i in values]) + + def get_dict(self, column): + items = zip(self.identifier, self.value, column.partition_columns) + return dict( + (id, col.to_database(val)) for id, val, col in items + ) + + def get_cql(self): + token_args = ', '.join(':{}'.format(id) for id in self.identifier) + return "token({})".format(token_args) - def get_value(self): - return self.value diff --git a/cqlengine/management.py b/cqlengine/management.py index 67c34f68..f5f6a649 100644 --- a/cqlengine/management.py +++ b/cqlengine/management.py @@ -61,20 +61,29 @@ def create_table(model, create_missing_keyspace=True): #add column types pkeys = [] + ckeys = [] qtypes = [] def add_column(col): s = col.get_column_def() - if col.primary_key: pkeys.append('"{}"'.format(col.db_field_name)) + if col.primary_key: + keys = (pkeys if col.partition_key else ckeys) + keys.append('"{}"'.format(col.db_field_name)) qtypes.append(s) for name, col in 
model._columns.items(): add_column(col) - qtypes.append('PRIMARY KEY ({})'.format(', '.join(pkeys))) + qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or '')) qs += ['({})'.format(', '.join(qtypes))] - + + with_qs = ['read_repair_chance = {}'.format(model.read_repair_chance)] + + _order = ["%s %s" % (c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()] + if _order: + with_qs.append("clustering order by ({})".format(', '.join(_order))) + # add read_repair_chance - qs += ['WITH read_repair_chance = {}'.format(model.read_repair_chance)] + qs += ['WITH {}'.format(' AND '.join(with_qs))] qs = ' '.join(qs) try: diff --git a/cqlengine/models.py b/cqlengine/models.py index aec06b4f..f06042d7 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -102,10 +102,10 @@ class BaseModel(object): if not include_keyspace: return cf_name return '{}.{}'.format(cls._get_keyspace(), cf_name) - @property - def pk(self): - """ Returns the object's primary key """ - return getattr(self, self._pk_name) + #@property + #def pk(self): + # """ Returns the object's primary key """ + # return getattr(self, self._pk_name) def validate(self): """ Cleans and validates the field values """ @@ -174,7 +174,6 @@ class ModelMetaClass(type): column_dict = OrderedDict() primary_keys = OrderedDict() pk_name = None - primary_key = None #get inherited properties inherited_columns = OrderedDict() @@ -210,24 +209,40 @@ class ModelMetaClass(type): k,v = 'id', columns.UUID(primary_key=True) column_definitions = [(k,v)] + column_definitions + has_partition_keys = any(v.partition_key for (k, v) in column_definitions) + #TODO: check that the defined columns don't conflict with any of the Model API's existing attributes/methods #transform column definitions for k,v in column_definitions: - if pk_name is None and v.primary_key: - pk_name = k - primary_key = v - v._partition_key = True + if not has_partition_keys and 
v.primary_key: + v.partition_key = True + has_partition_keys = True _transform_column(k,v) - - #setup primary key shortcut - if pk_name != 'pk': - attrs['pk'] = attrs[pk_name] - #check for duplicate column names + partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) + clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key) + + #setup partition key shortcut + assert partition_keys + if len(partition_keys) == 1: + pk_name = partition_keys.keys()[0] + attrs['pk'] = attrs[pk_name] + else: + # composite partition key case + _get = lambda self: tuple(self._values[c].getval() for c in partition_keys.keys()) + _set = lambda self, val: tuple(self._values[c].setval(v) for (c, v) in zip(partition_keys.keys(), val)) + attrs['pk'] = property(_get, _set) + + # some validation col_names = set() for v in column_dict.values(): + # check for duplicate column names if v.db_field_name in col_names: raise ModelException("{} defines the column {} more than once".format(name, v.db_field_name)) + if v.clustering_order and not (v.primary_key and not v.partition_key): + raise ModelException("clustering_order may be specified only for clustering primary keys") + if v.clustering_order and v.clustering_order.lower() not in ('asc', 'desc'): + raise ModelException("invalid clustering order {} for column {}".format(repr(v.clustering_order), v.db_field_name)) col_names.add(v.db_field_name) #create db_name -> model name map for loading @@ -244,9 +259,11 @@ class ModelMetaClass(type): attrs['_defined_columns'] = defined_columns attrs['_db_map'] = db_map attrs['_pk_name'] = pk_name - attrs['_primary_key'] = primary_key attrs['_dynamic_columns'] = {} + attrs['_partition_keys'] = partition_keys + attrs['_clustering_keys'] = clustering_keys + #create the class and add a QuerySet to it klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs) klass.objects = QuerySet(klass) diff --git a/cqlengine/query.py b/cqlengine/query.py index 
a4690df0..19974eeb 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -4,11 +4,11 @@ from datetime import datetime from hashlib import md5 from time import time from uuid import uuid1 -from cqlengine import BaseContainerColumn, BaseValueManager, Map +from cqlengine import BaseContainerColumn, BaseValueManager, Map, columns from cqlengine.connection import connection_manager from cqlengine.exceptions import CQLEngineException -from cqlengine.functions import BaseQueryFunction, Token +from cqlengine.functions import QueryValue, Token #CQL 3 reference: #http://www.datastax.com/docs/1.1/references/cql/index @@ -24,14 +24,16 @@ class QueryOperator(object): # The comparator symbol this operator uses in cql cql_symbol = None + QUERY_VALUE_WRAPPER = QueryValue + def __init__(self, column, value): self.column = column self.value = value - #the identifier is a unique key that will be used in string - #replacement on query strings, it's created from a hash - #of this object's id and the time - self.identifier = uuid1().hex + if isinstance(value, QueryValue): + self.query_value = value + else: + self.query_value = self.QUERY_VALUE_WRAPPER(value) #perform validation on this operator self.validate_operator() @@ -41,12 +43,8 @@ class QueryOperator(object): def cql(self): """ Returns this operator's portion of the WHERE clause - :param valname: the dict key that this operator's compare value will be found in """ - if isinstance(self.value, BaseQueryFunction): - return self.value.format_cql(self.column.db_field_name, self.cql_symbol, self.identifier) - else: - return '"{}" {} :{}'.format(self.column.db_field_name, self.cql_symbol, self.identifier) + return '{} {} {}'.format(self.column.cql, self.cql_symbol, self.query_value.cql) def validate_operator(self): """ @@ -81,10 +79,7 @@ class QueryOperator(object): this should return the dict: {'colval':} SELECT * FROM column_family WHERE colname=:colval """ - if isinstance(self.value, BaseQueryFunction): - return 
{self.identifier: self.column.to_database(self.value.get_value())} - else: - return {self.identifier: self.column.to_database(self.value)} + return self.query_value.get_dict(self.column) @classmethod def get_operator(cls, symbol): @@ -102,34 +97,34 @@ class QueryOperator(object): except KeyError: raise QueryOperatorException("{} doesn't map to a QueryOperator".format(symbol)) + def __eq__(self, op): + return self.__class__ is op.__class__ and self.column == op.column and self.value == op.value + + def __ne__(self, op): + return not (self == op) + class EqualsOperator(QueryOperator): symbol = 'EQ' cql_symbol = '=' +class IterableQueryValue(QueryValue): + def __init__(self, value): + try: + super(IterableQueryValue, self).__init__(value, [uuid1().hex for i in value]) + except TypeError: + raise QueryException("in operator arguments must be iterable, {} found".format(value)) + + def get_dict(self, column): + return dict((i, column.to_database(v)) for (i, v) in zip(self.identifier, self.value)) + + def get_cql(self): + return '({})'.format(', '.join(':{}'.format(i) for i in self.identifier)) + class InOperator(EqualsOperator): symbol = 'IN' cql_symbol = 'IN' - class Quoter(object): - """ - contains a single value, which will quote itself for CQL insertion statements - """ - def __init__(self, value): - self.value = value - - def __str__(self): - from cql.query import cql_quote as cq - return '(' + ', '.join([cq(v) for v in self.value]) + ')' - - def get_dict(self): - if isinstance(self.value, BaseQueryFunction): - return {self.identifier: self.column.to_database(self.value.get_value())} - else: - try: - values = [v for v in self.value] - except TypeError: - raise QueryException("in operator arguments must be iterable, {} found".format(self.value)) - return {self.identifier: self.Quoter([self.column.to_database(v) for v in self.value])} + QUERY_VALUE_WRAPPER = IterableQueryValue class GreaterThanOperator(QueryOperator): symbol = "GT" @@ -286,9 +281,9 @@ class 
QuerySet(object): if not self._allow_filtering: #if the query is not on an indexed field if not any([w.column.index for w in equal_ops]): - if not any([w.column._partition_key for w in equal_ops]) and not token_ops: + if not any([w.column.partition_key for w in equal_ops]) and not token_ops: raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset') - if any(not w.column._partition_key for w in token_ops): + if any(not w.column.partition_key for w in token_ops): raise QueryException('The token() function is only supported on the partition key') @@ -314,7 +309,7 @@ class QuerySet(object): if self._defer_fields: fields = [f for f in fields if f not in self._defer_fields] elif self._only_fields: - fields = [f for f in fields if f in self._only_fields] + fields = self._only_fields db_fields = [self.model._columns[f].db_field_name for f in fields] qs = ['SELECT {}'.format(', '.join(['"{}"'.format(f) for f in db_fields]))] @@ -449,7 +444,7 @@ class QuerySet(object): __ :returns: colname, op tuple """ - statement = arg.split('__') + statement = arg.rsplit('__', 1) if len(statement) == 1: return arg, None elif len(statement) == 2: @@ -466,7 +461,10 @@ class QuerySet(object): try: column = self.model._columns[col_name] except KeyError: - raise QueryException("Can't resolve column name: '{}'".format(col_name)) + if col_name == 'pk__token': + column = columns._PartitionKeysToken(self.model) + else: + raise QueryException("Can't resolve column name: '{}'".format(col_name)) #get query operator, or use equals if not supplied operator_class = QueryOperator.get_operator(col_op or 'EQ') @@ -640,6 +638,12 @@ class QuerySet(object): clone._flat_values_list = flat return clone + def __eq__(self, q): + return self._where == q._where + + def __ne__(self, q): + return not (self == q) + class DMLQuery(object): """ A query object used for queries performing inserts, updates, or deletes diff --git 
a/cqlengine/tests/model/test_class_construction.py b/cqlengine/tests/model/test_class_construction.py index 77b1d737..02d51f56 100644 --- a/cqlengine/tests/model/test_class_construction.py +++ b/cqlengine/tests/model/test_class_construction.py @@ -117,7 +117,29 @@ class TestModelClassFunction(BaseCassEngTestCase): """ Test that metadata defined in one class, is not inherited by subclasses """ - + + def test_partition_keys(self): + """ + Test compound partition key definition + """ + class ModelWithPartitionKeys(cqlengine.Model): + c1 = cqlengine.Text(primary_key=True) + p1 = cqlengine.Text(partition_key=True) + p2 = cqlengine.Text(partition_key=True) + + cols = ModelWithPartitionKeys._columns + + self.assertTrue(cols['c1'].primary_key) + self.assertFalse(cols['c1'].partition_key) + + self.assertTrue(cols['p1'].primary_key) + self.assertTrue(cols['p1'].partition_key) + self.assertTrue(cols['p2'].primary_key) + self.assertTrue(cols['p2'].partition_key) + + obj = ModelWithPartitionKeys(p1='a', p2='b') + self.assertEquals(obj.pk, ('a', 'b')) + class TestManualTableNaming(BaseCassEngTestCase): class RenamedTest(cqlengine.Model): diff --git a/cqlengine/tests/model/test_clustering_order.py b/cqlengine/tests/model/test_clustering_order.py new file mode 100644 index 00000000..92d8e067 --- /dev/null +++ b/cqlengine/tests/model/test_clustering_order.py @@ -0,0 +1,35 @@ +import random +from cqlengine.tests.base import BaseCassEngTestCase + +from cqlengine.management import create_table +from cqlengine.management import delete_table +from cqlengine.models import Model +from cqlengine import columns + +class TestModel(Model): + id = columns.Integer(primary_key=True) + clustering_key = columns.Integer(primary_key=True, clustering_order='desc') + +class TestClusteringOrder(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestClusteringOrder, cls).setUpClass() + create_table(TestModel) + + @classmethod + def tearDownClass(cls): + super(TestClusteringOrder, 
cls).tearDownClass() + delete_table(TestModel) + + def test_clustering_order(self): + """ + Tests that models can be saved and retrieved + """ + items = list(range(20)) + random.shuffle(items) + for i in items: + TestModel.create(id=1, clustering_key=i) + + values = list(TestModel.objects.values_list('clustering_key', flat=True)) + self.assertEquals(values, sorted(items, reverse=True)) diff --git a/cqlengine/tests/query/test_queryoperators.py b/cqlengine/tests/query/test_queryoperators.py index 5db4a299..8f5243fb 100644 --- a/cqlengine/tests/query/test_queryoperators.py +++ b/cqlengine/tests/query/test_queryoperators.py @@ -17,7 +17,7 @@ class TestQuerySetOperation(BaseCassEngTestCase): col.set_column_name('time') qry = query.EqualsOperator(col, functions.MaxTimeUUID(now)) - assert qry.cql == '"time" = MaxTimeUUID(:{})'.format(qry.identifier) + assert qry.cql == '"time" = MaxTimeUUID(:{})'.format(qry.value.identifier) def test_mintimeuuid_function(self): """ @@ -28,7 +28,23 @@ class TestQuerySetOperation(BaseCassEngTestCase): col.set_column_name('time') qry = query.EqualsOperator(col, functions.MinTimeUUID(now)) - assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.identifier) + assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.value.identifier) + def test_token_function(self): + class TestModel(Model): + p1 = columns.Text(partition_key=True) + p2 = columns.Text(partition_key=True) + + func = functions.Token('a', 'b') + + q = TestModel.objects.filter(pk__token__gt=func) + self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier)) + + # Token(tuple()) is also possible for convinience + # it (allows for Token(obj.pk) syntax) + func = functions.Token(('a', 'b')) + + q = TestModel.objects.filter(pk__token__gt=func) + self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier)) diff --git a/cqlengine/tests/query/test_queryset.py b/cqlengine/tests/query/test_queryset.py index 
a77d7502..79ca5af0 100644 --- a/cqlengine/tests/query/test_queryset.py +++ b/cqlengine/tests/query/test_queryset.py @@ -64,12 +64,12 @@ class TestQuerySetOperation(BaseCassEngTestCase): Tests the where clause creation """ query1 = TestModel.objects(test_id=5) - ids = [o.identifier for o in query1._where] + ids = [o.query_value.identifier for o in query1._where] where = query1._where_clause() assert where == '"test_id" = :{}'.format(*ids) query2 = query1.filter(expected_result__gte=1) - ids = [o.identifier for o in query2._where] + ids = [o.query_value.identifier for o in query2._where] where = query2._where_clause() assert where == '"test_id" = :{} AND "expected_result" >= :{}'.format(*ids) @@ -470,5 +470,12 @@ class TestInOperator(BaseQuerySetUsage): assert q.count() == 8 +class TestValuesList(BaseQuerySetUsage): + def test_values_list(self): + q = TestModel.objects.filter(test_id=0, attempt_id=1) + item = q.values_list('test_id', 'attempt_id', 'description', 'expected_result', 'test_result').first() + assert item == [0, 1, 'try2', 10, 30] + item = q.values_list('expected_result', flat=True).first() + assert item == 10 diff --git a/docs/topics/columns.rst b/docs/topics/columns.rst index 0537e3bc..0779c433 100644 --- a/docs/topics/columns.rst +++ b/docs/topics/columns.rst @@ -145,12 +145,16 @@ Column Options If True, this column is created as a primary key field. A model can have multiple primary keys. Defaults to False. - *In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first primary key is the partition key, and all others are clustering keys.* + *In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first primary key is the partition key, and all others are clustering keys, unless partition keys are specified manually using* :attr:`BaseColumn.partition_key` + + .. attribute:: BaseColumn.partition_key + + If True, this column is created as partition primary key. 
There may be many partition keys defined, forming *composite partition key* .. attribute:: BaseColumn.index If True, an index will be created for this column. Defaults to False. - + *Note: Indexes can only be created on models with one primary key* .. attribute:: BaseColumn.db_field @@ -165,3 +169,6 @@ Column Options If True, this model cannot be saved without a value defined for this column. Defaults to True. Primary key fields cannot have their required fields set to False. + .. attribute:: BaseColumn.clustering_order + + Defines CLUSTERING ORDER for this column (valid choices are "asc" (default) or "desc"). It may be specified only for clustering primary keys - more: http://www.datastax.com/docs/1.2/cql_cli/cql/CREATE_TABLE#using-clustering-order diff --git a/docs/topics/models.rst b/docs/topics/models.rst index 0dd2c19c..52eb9061 100644 --- a/docs/topics/models.rst +++ b/docs/topics/models.rst @@ -66,7 +66,10 @@ Column Options :attr:`~cqlengine.columns.BaseColumn.primary_key` If True, this column is created as a primary key field. A model can have multiple primary keys. Defaults to False. - *In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first primary key is the partition key, and all others are clustering keys.* + *In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first primary key is the partition key, and all others are clustering keys, unless partition keys are specified manually using* :attr:`~cqlengine.columns.BaseColumn.partition_key` + + :attr:`~cqlengine.columns.BaseColumn.partition_key` + If True, this column is created as partition primary key. There may be many partition keys defined, forming *composite partition key* :attr:`~cqlengine.columns.BaseColumn.index` If True, an index will be created for this column. Defaults to False. 
diff --git a/docs/topics/queryset.rst b/docs/topics/queryset.rst index 6733080c..8490ce3d 100644 --- a/docs/topics/queryset.rst +++ b/docs/topics/queryset.rst @@ -178,6 +178,26 @@ TimeUUID Functions DataStream.filter(time__gt=cqlengine.MinTimeUUID(min_time), time__lt=cqlengine.MaxTimeUUID(max_time)) +Token Function +============== + + The Token function may be used only on the special virtual column pk__token, which represents the token of the partition key (it also works for composite partition keys). + Cassandra orders returned items by the value of the partition key token, so using cqlengine.Token we can easily paginate through all table rows. + + *Example* + + .. code-block:: python + + class Items(Model): + id = cqlengine.Text(primary_key=True) + data = cqlengine.Bytes() + + query = Items.objects.all().limit(10) + + first_page = list(query); + last = first_page[-1] + next_page = list(query.filter(pk__token__gt=cqlengine.Token(last.pk))) + QuerySets are imutable ====================== @@ -213,6 +233,13 @@ Ordering QuerySets *For instance, given our Automobile model, year is the only column we can order on.* +Values Lists +============ + + QuerySet has a special method, ``.values_list()`` - when called, the QuerySet returns lists of values instead of model instances. This can significantly speed things up and lower the memory footprint for large responses. + Each list contains the values of the respective fields passed into the ``values_list()`` call — so the first item is the first field, etc. For example, ``values_list('test_id', 'attempt_id')`` yields ``[test_id, attempt_id]`` for each row. + + Batch Queries ===============