cqle: attach routing key to mapper statements
PYTHON-535
This commit is contained in:
		| @@ -19,7 +19,7 @@ import six | ||||
| from uuid import UUID as _UUID | ||||
|  | ||||
| from cassandra import util | ||||
| from cassandra.cqltypes import SimpleDateType | ||||
| from cassandra.cqltypes import SimpleDateType, _cqltypes, UserType | ||||
| from cassandra.cqlengine import ValidationError | ||||
| from cassandra.cqlengine.functions import get_total_seconds | ||||
|  | ||||
| @@ -256,6 +256,10 @@ class Column(object): | ||||
|     def sub_types(self): | ||||
|         return [] | ||||
|  | ||||
|     @property | ||||
|     def cql_type(self): | ||||
|         return _cqltypes[self.db_type] | ||||
|  | ||||
|  | ||||
| class Blob(Column): | ||||
|     """ | ||||
| @@ -666,6 +670,10 @@ class BaseCollectionColumn(Column): | ||||
|     def sub_types(self): | ||||
|         return self.types | ||||
|  | ||||
|     @property | ||||
|     def cql_type(self): | ||||
|         return _cqltypes[self.__class__.__name__.lower()].apply_parameters([c.cql_type for c in self.types]) | ||||
|  | ||||
|  | ||||
| class Tuple(BaseCollectionColumn): | ||||
|     """ | ||||
| @@ -877,6 +885,10 @@ class UserDefinedType(Column): | ||||
|     def sub_types(self): | ||||
|         return list(self.user_type._fields.values()) | ||||
|  | ||||
|     @property | ||||
|     def cql_type(self): | ||||
|         return UserType.apply_parameters([c.cql_type for c in self.user_type._fields.values()]) | ||||
|  | ||||
|  | ||||
| def resolve_udts(col_def, out_list): | ||||
|     for col in col_def.sub_types: | ||||
|   | ||||
| @@ -379,6 +379,10 @@ class BaseModel(object): | ||||
|         return '{0} <{1}>'.format(self.__class__.__name__, | ||||
|                                 ', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys())) | ||||
|  | ||||
|     @classmethod | ||||
|     def _routing_key_from_values(cls, pk_values, protocol_version): | ||||
|         return cls._key_serializer(pk_values, protocol_version) | ||||
|  | ||||
|     @classmethod | ||||
|     def _discover_polymorphic_submodels(cls): | ||||
|         if not cls._is_polymorphic_base: | ||||
| @@ -867,6 +871,11 @@ class ModelMetaClass(type): | ||||
|         partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) | ||||
|         clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key) | ||||
|  | ||||
|         key_cols = [c for c in partition_keys.values()] | ||||
|         partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols) | ||||
|         key_cql_types = [c.cql_type for c in key_cols] | ||||
|         key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)]) | ||||
|  | ||||
|         # setup partition key shortcut | ||||
|         if len(partition_keys) == 0: | ||||
|             if not is_abstract: | ||||
| @@ -910,6 +919,8 @@ class ModelMetaClass(type): | ||||
|         attrs['_dynamic_columns'] = {} | ||||
|  | ||||
|         attrs['_partition_keys'] = partition_keys | ||||
|         attrs['_partition_key_index'] = partition_key_index | ||||
|         attrs['_key_serializer'] = key_serializer | ||||
|         attrs['_clustering_keys'] = clustering_keys | ||||
|         attrs['_has_counter'] = len(counter_columns) > 0 | ||||
|  | ||||
|   | ||||
| @@ -84,6 +84,8 @@ class NamedTable(object): | ||||
|  | ||||
|     __partition_keys = None | ||||
|  | ||||
|     _partition_key_index = None | ||||
|  | ||||
|     class DoesNotExist(_DoesNotExist): | ||||
|         pass | ||||
|  | ||||
|   | ||||
| @@ -19,6 +19,7 @@ import time | ||||
| import six | ||||
| from warnings import warn | ||||
|  | ||||
| from cassandra.query import SimpleStatement | ||||
| from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin | ||||
| from cassandra.cqlengine import connection | ||||
| from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue | ||||
| @@ -310,11 +311,11 @@ class AbstractQuerySet(object): | ||||
|     def column_family_name(self): | ||||
|         return self.model.column_family_name() | ||||
|  | ||||
|     def _execute(self, q): | ||||
|     def _execute(self, statement): | ||||
|         if self._batch: | ||||
|             return self._batch.add_query(q) | ||||
|             return self._batch.add_query(statement) | ||||
|         else: | ||||
|             result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) | ||||
|             result = _execute_statement(self.model, statement, self._consistency, None, self._timeout) | ||||
|             if self._if_not_exists or self._if_exists or self._conditional: | ||||
|                 check_applied(result) | ||||
|             return result | ||||
| @@ -1205,14 +1206,14 @@ class DMLQuery(object): | ||||
|         self._conditional = conditional | ||||
|         self._timeout = timeout | ||||
|  | ||||
|     def _execute(self, q): | ||||
|     def _execute(self, statement): | ||||
|         if self._batch: | ||||
|             return self._batch.add_query(q) | ||||
|             return self._batch.add_query(statement) | ||||
|         else: | ||||
|             tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) | ||||
|             results = _execute_statement(self.model, statement, self._consistency, None, self._timeout) | ||||
|             if self._if_not_exists or self._if_exists or self._conditional: | ||||
|                 check_applied(tmp) | ||||
|             return tmp | ||||
|                 check_applied(results) | ||||
|             return results | ||||
|  | ||||
|     def batch(self, batch_obj): | ||||
|         if batch_obj is not None and not isinstance(batch_obj, BatchQuery): | ||||
| @@ -1338,3 +1339,14 @@ class DMLQuery(object): | ||||
|                 continue | ||||
|             ds.add_where(col, EqualsOperator(), val) | ||||
|         self._execute(ds) | ||||
|  | ||||
|  | ||||
| def _execute_statement(model, statement, consistency_level, fetch_size, timeout): | ||||
|     params = statement.get_context() | ||||
|     s = SimpleStatement(str(statement), consistency_level=consistency_level, fetch_size=fetch_size) | ||||
|     if model._partition_key_index: | ||||
|         key_values = statement.partition_key_values(model._partition_key_index) | ||||
|         if not any(v is None for v in key_values): | ||||
|             parts = model._routing_key_from_values(key_values, connection.get_cluster().protocol_version) | ||||
|             s.routing_key = parts | ||||
|     return connection.execute(s, params, timeout=timeout) | ||||
|   | ||||
| @@ -21,7 +21,7 @@ from cassandra.query import FETCH_SIZE_UNSET | ||||
| from cassandra.cqlengine import columns | ||||
| from cassandra.cqlengine import UnicodeMixin | ||||
| from cassandra.cqlengine.functions import QueryValue | ||||
| from cassandra.cqlengine.operators import BaseWhereOperator, InOperator | ||||
| from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator | ||||
|  | ||||
|  | ||||
| class StatementException(Exception): | ||||
| @@ -505,7 +505,7 @@ class BaseCQLStatement(UnicodeMixin): | ||||
|  | ||||
|     def partition_key_values(self, field_index_map): | ||||
|         parts = [None] * len(field_index_map) | ||||
|         self._update_part_key_values(field_index_map, self.where_clauses, parts) | ||||
|         self._update_part_key_values(field_index_map, (w for w in self.where_clauses if isinstance(w, EqualsOperator)), parts) | ||||
|         return parts | ||||
|  | ||||
|     def add_where(self, column, operator, value, quote_field=True): | ||||
|   | ||||
| @@ -61,7 +61,6 @@ class BaseColumnIOTest(BaseCassEngTestCase): | ||||
|  | ||||
|         # create a table with the given column | ||||
|         class IOTestModel(Model): | ||||
|             table_name = cls.column.db_type + "_io_test_model_{0}".format(uuid4().hex[:8]) | ||||
|             pkey = cls.column(primary_key=True) | ||||
|             data = cls.column() | ||||
|  | ||||
|   | ||||
| @@ -22,7 +22,6 @@ from datetime import datetime | ||||
| import time | ||||
| from uuid import uuid1, uuid4 | ||||
| import uuid | ||||
| import sys | ||||
|  | ||||
| from cassandra.cluster import Session | ||||
| from cassandra import InvalidRequest | ||||
| @@ -457,8 +456,8 @@ def test_non_quality_filtering(): | ||||
|     NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now()) | ||||
|     NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now()) | ||||
|  | ||||
|     qA = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering() | ||||
|     num = qA.count() | ||||
|     qa = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering() | ||||
|     num = qa.count() | ||||
|     assert num == 1, num | ||||
|  | ||||
|  | ||||
| @@ -473,7 +472,7 @@ class TestQuerySetDistinct(BaseQuerySetUsage): | ||||
|         self.assertEqual(len(q), 3) | ||||
|  | ||||
|     def test_distinct_with_filter(self): | ||||
|         q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1,2]) | ||||
|         q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2]) | ||||
|         self.assertEqual(len(q), 2) | ||||
|  | ||||
|     def test_distinct_with_non_partition(self): | ||||
| @@ -510,19 +509,19 @@ class TestQuerySetOrdering(BaseQuerySetUsage): | ||||
|     def test_ordering_by_non_second_primary_keys_fail(self): | ||||
|         # kwarg filtering | ||||
|         with self.assertRaises(query.QueryException): | ||||
|             q = TestModel.objects(test_id=0).order_by('test_id') | ||||
|             TestModel.objects(test_id=0).order_by('test_id') | ||||
|  | ||||
|         # kwarg filtering | ||||
|         with self.assertRaises(query.QueryException): | ||||
|             q = TestModel.objects(TestModel.test_id == 0).order_by('test_id') | ||||
|             TestModel.objects(TestModel.test_id == 0).order_by('test_id') | ||||
|  | ||||
|     def test_ordering_by_non_primary_keys_fails(self): | ||||
|         with self.assertRaises(query.QueryException): | ||||
|             q = TestModel.objects(test_id=0).order_by('description') | ||||
|             TestModel.objects(test_id=0).order_by('description') | ||||
|  | ||||
|     def test_ordering_on_indexed_columns_fails(self): | ||||
|         with self.assertRaises(query.QueryException): | ||||
|             q = IndexedTestModel.objects(test_id=0).order_by('attempt_id') | ||||
|             IndexedTestModel.objects(test_id=0).order_by('attempt_id') | ||||
|  | ||||
|     def test_ordering_on_multiple_clustering_columns(self): | ||||
|         TestMultiClusteringModel.create(one=1, two=1, three=4) | ||||
| @@ -673,7 +672,7 @@ class TestQuerySetDelete(BaseQuerySetUsage): | ||||
|         TestMultiClusteringModel.objects(one=1, two__gt=3, two__lt=5).delete() | ||||
|         self.assertEqual(5, len(TestMultiClusteringModel.objects.all())) | ||||
|  | ||||
|         TestMultiClusteringModel.objects(one=1, two__in=[8,9]).delete() | ||||
|         TestMultiClusteringModel.objects(one=1, two__in=[8, 9]).delete() | ||||
|         self.assertEqual(3, len(TestMultiClusteringModel.objects.all())) | ||||
|  | ||||
|         TestMultiClusteringModel.objects(one__in=[1], two__gte=0).delete() | ||||
| @@ -878,7 +877,7 @@ class TestValuesList(BaseQuerySetUsage): | ||||
| class TestObjectsProperty(BaseQuerySetUsage): | ||||
|     def test_objects_property_returns_fresh_queryset(self): | ||||
|         assert TestModel.objects._result_cache is None | ||||
|         len(TestModel.objects) # evaluate queryset | ||||
|         len(TestModel.objects)  # evaluate queryset | ||||
|         assert TestModel.objects._result_cache is None | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -74,7 +74,7 @@ class RoutingTests(unittest.TestCase): | ||||
|         select = s.prepare("SELECT token(%s) FROM %s WHERE %s" % | ||||
|                            (primary_key, table_name, where_clause)) | ||||
|  | ||||
|         return (insert, select) | ||||
|         return insert, select | ||||
|  | ||||
|     def test_singular_key(self): | ||||
|         # string | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Adam Holmberg
					Adam Holmberg