diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py index f6f843db..1ee60137 100644 --- a/cassandra/cqlengine/columns.py +++ b/cassandra/cqlengine/columns.py @@ -19,7 +19,7 @@ import six from uuid import UUID as _UUID from cassandra import util -from cassandra.cqltypes import SimpleDateType +from cassandra.cqltypes import SimpleDateType, _cqltypes, UserType from cassandra.cqlengine import ValidationError from cassandra.cqlengine.functions import get_total_seconds @@ -256,6 +256,10 @@ class Column(object): def sub_types(self): return [] + @property + def cql_type(self): + return _cqltypes[self.db_type] + class Blob(Column): """ @@ -666,6 +670,10 @@ class BaseCollectionColumn(Column): def sub_types(self): return self.types + @property + def cql_type(self): + return _cqltypes[self.__class__.__name__.lower()].apply_parameters([c.cql_type for c in self.types]) + class Tuple(BaseCollectionColumn): """ @@ -877,6 +885,10 @@ class UserDefinedType(Column): def sub_types(self): return list(self.user_type._fields.values()) + @property + def cql_type(self): + return UserType.apply_parameters([c.cql_type for c in self.user_type._fields.values()]) + def resolve_udts(col_def, out_list): for col in col_def.sub_types: diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py index ff732075..78d29901 100644 --- a/cassandra/cqlengine/models.py +++ b/cassandra/cqlengine/models.py @@ -379,6 +379,10 @@ class BaseModel(object): return '{0} <{1}>'.format(self.__class__.__name__, ', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys())) + @classmethod + def _routing_key_from_values(cls, pk_values, protocol_version): + return cls._key_serializer(pk_values, protocol_version) + @classmethod def _discover_polymorphic_submodels(cls): if not cls._is_polymorphic_base: @@ -867,6 +871,11 @@ class ModelMetaClass(type): partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key) + key_cols = [c for c in partition_keys.values()] + partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols) + key_cql_types = [c.cql_type for c in key_cols] + key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)]) + # setup partition key shortcut if len(partition_keys) == 0: if not is_abstract: @@ -910,6 +919,8 @@ class ModelMetaClass(type): attrs['_dynamic_columns'] = {} attrs['_partition_keys'] = partition_keys + attrs['_partition_key_index'] = partition_key_index + attrs['_key_serializer'] = key_serializer attrs['_clustering_keys'] = clustering_keys attrs['_has_counter'] = len(counter_columns) > 0 diff --git a/cassandra/cqlengine/named.py b/cassandra/cqlengine/named.py index 90a0d8fd..07b4c50b 100644 --- a/cassandra/cqlengine/named.py +++ b/cassandra/cqlengine/named.py @@ -84,6 +84,8 @@ class NamedTable(object): __partition_keys = None + _partition_key_index = None + class DoesNotExist(_DoesNotExist): pass diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index 7ac8f362..f2856bf1 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -19,6 +19,7 @@ import time import six from warnings import warn +from cassandra.query import SimpleStatement from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin from cassandra.cqlengine import connection from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue @@ -310,11 +311,11 @@ class AbstractQuerySet(object): def column_family_name(self): return self.model.column_family_name() - def _execute(self, q): + def _execute(self, statement): if self._batch: - return self._batch.add_query(q) + return self._batch.add_query(statement) else: - result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) + result = _execute_statement(self.model, statement, self._consistency, None, self._timeout) if self._if_not_exists or self._if_exists or self._conditional: check_applied(result) return result @@ -1205,14 +1206,14 @@ class DMLQuery(object): self._conditional = conditional self._timeout = timeout - def _execute(self, q): + def _execute(self, statement): if self._batch: - return self._batch.add_query(q) + return self._batch.add_query(statement) else: - tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) + results = _execute_statement(self.model, statement, self._consistency, None, self._timeout) if self._if_not_exists or self._if_exists or self._conditional: - check_applied(tmp) - return tmp + check_applied(results) + return results def batch(self, batch_obj): if batch_obj is not None and not isinstance(batch_obj, BatchQuery): @@ -1338,3 +1339,14 @@ class DMLQuery(object): continue ds.add_where(col, EqualsOperator(), val) self._execute(ds) + + +def _execute_statement(model, statement, consistency_level, fetch_size, timeout): + params = statement.get_context() + s = SimpleStatement(str(statement), consistency_level=consistency_level, fetch_size=fetch_size) + if model._partition_key_index: + key_values = statement.partition_key_values(model._partition_key_index) + if not any(v is None for v in key_values): + parts = model._routing_key_from_values(key_values, connection.get_cluster().protocol_version) + s.routing_key = parts + return connection.execute(s, params, timeout=timeout) diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index 1df7ead1..21044f8d 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -21,7 +21,7 @@ from cassandra.query import FETCH_SIZE_UNSET from cassandra.cqlengine import columns from cassandra.cqlengine import UnicodeMixin from cassandra.cqlengine.functions import QueryValue -from cassandra.cqlengine.operators import BaseWhereOperator, InOperator +from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator class StatementException(Exception): @@ -505,7 +505,7 @@ class BaseCQLStatement(UnicodeMixin): def partition_key_values(self, field_index_map): parts = [None] * len(field_index_map) - self._update_part_key_values(field_index_map, self.where_clauses, parts) + self._update_part_key_values(field_index_map, (w for w in self.where_clauses if isinstance(w, EqualsOperator)), parts) return parts def add_where(self, column, operator, value, quote_field=True): diff --git a/tests/integration/cqlengine/columns/test_value_io.py b/tests/integration/cqlengine/columns/test_value_io.py index 455f80c1..42dc2420 100644 --- a/tests/integration/cqlengine/columns/test_value_io.py +++ b/tests/integration/cqlengine/columns/test_value_io.py @@ -61,7 +61,6 @@ class BaseColumnIOTest(BaseCassEngTestCase): # create a table with the given column class IOTestModel(Model): - table_name = cls.column.db_type + "_io_test_model_{0}".format(uuid4().hex[:8]) pkey = cls.column(primary_key=True) data = cls.column() diff --git a/tests/integration/cqlengine/query/test_queryset.py b/tests/integration/cqlengine/query/test_queryset.py index 0097e604..a7364f2a 100644 --- a/tests/integration/cqlengine/query/test_queryset.py +++ b/tests/integration/cqlengine/query/test_queryset.py @@ -22,7 +22,6 @@ from datetime import datetime import time from uuid import uuid1, uuid4 import uuid -import sys from cassandra.cluster import Session from cassandra import InvalidRequest @@ -457,8 +456,8 @@ def test_non_quality_filtering(): NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now()) NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now()) - qA = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering() - num = qA.count() + qa = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering() + num = qa.count() assert num == 1, num @@ -473,7 +472,7 @@ class TestQuerySetDistinct(BaseQuerySetUsage): self.assertEqual(len(q), 3) def test_distinct_with_filter(self): - q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1,2]) + q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2]) self.assertEqual(len(q), 2) def test_distinct_with_non_partition(self): @@ -510,19 +509,19 @@ class TestQuerySetOrdering(BaseQuerySetUsage): def test_ordering_by_non_second_primary_keys_fail(self): # kwarg filtering with self.assertRaises(query.QueryException): - q = TestModel.objects(test_id=0).order_by('test_id') + TestModel.objects(test_id=0).order_by('test_id') # kwarg filtering with self.assertRaises(query.QueryException): - q = TestModel.objects(TestModel.test_id == 0).order_by('test_id') + TestModel.objects(TestModel.test_id == 0).order_by('test_id') def test_ordering_by_non_primary_keys_fails(self): with self.assertRaises(query.QueryException): - q = TestModel.objects(test_id=0).order_by('description') + TestModel.objects(test_id=0).order_by('description') def test_ordering_on_indexed_columns_fails(self): with self.assertRaises(query.QueryException): - q = IndexedTestModel.objects(test_id=0).order_by('attempt_id') + IndexedTestModel.objects(test_id=0).order_by('attempt_id') def test_ordering_on_multiple_clustering_columns(self): TestMultiClusteringModel.create(one=1, two=1, three=4) @@ -673,7 +672,7 @@ class TestQuerySetDelete(BaseQuerySetUsage): TestMultiClusteringModel.objects(one=1, two__gt=3, two__lt=5).delete() self.assertEqual(5, len(TestMultiClusteringModel.objects.all())) - TestMultiClusteringModel.objects(one=1, two__in=[8,9]).delete() + TestMultiClusteringModel.objects(one=1, two__in=[8, 9]).delete() self.assertEqual(3, len(TestMultiClusteringModel.objects.all())) TestMultiClusteringModel.objects(one__in=[1], two__gte=0).delete() @@ -878,7 +877,7 @@ class TestValuesList(BaseQuerySetUsage): class TestObjectsProperty(BaseQuerySetUsage): def test_objects_property_returns_fresh_queryset(self): assert TestModel.objects._result_cache is None - len(TestModel.objects) # evaluate queryset + len(TestModel.objects) # evaluate queryset assert TestModel.objects._result_cache is None diff --git a/tests/integration/standard/test_routing.py b/tests/integration/standard/test_routing.py index 1ad0a2fa..b22184a0 100644 --- a/tests/integration/standard/test_routing.py +++ b/tests/integration/standard/test_routing.py @@ -74,7 +74,7 @@ class RoutingTests(unittest.TestCase): select = s.prepare("SELECT token(%s) FROM %s WHERE %s" % (primary_key, table_name, where_clause)) - return (insert, select) + return insert, select def test_singular_key(self): # string