cqle: attach routing key to mapper statements
PYTHON-535
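In short: the model metaclass now records each model's partition key order and CQL types and builds a key serializer from them; mapper queries are executed through a new _execute_statement helper that wraps the rendered CQL in a SimpleStatement and, when every partition key column is bound by equality, sets the statement's routing_key, so token-aware load balancing can apply to cqlengine statements.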
@@ -19,7 +19,7 @@ import six
 from uuid import UUID as _UUID

 from cassandra import util
-from cassandra.cqltypes import SimpleDateType
+from cassandra.cqltypes import SimpleDateType, _cqltypes, UserType
 from cassandra.cqlengine import ValidationError
 from cassandra.cqlengine.functions import get_total_seconds
@@ -256,6 +256,10 @@ class Column(object):
     def sub_types(self):
         return []

+    @property
+    def cql_type(self):
+        return _cqltypes[self.db_type]
+

 class Blob(Column):
     """
@@ -666,6 +670,10 @@ class BaseCollectionColumn(Column):
     def sub_types(self):
         return self.types

+    @property
+    def cql_type(self):
+        return _cqltypes[self.__class__.__name__.lower()].apply_parameters([c.cql_type for c in self.types])
+

 class Tuple(BaseCollectionColumn):
     """
@@ -877,6 +885,10 @@ class UserDefinedType(Column):
     def sub_types(self):
         return list(self.user_type._fields.values())

+    @property
+    def cql_type(self):
+        return UserType.apply_parameters([c.cql_type for c in self.user_type._fields.values()])
+

 def resolve_udts(col_def, out_list):
     for col in col_def.sub_types:
@@ -379,6 +379,10 @@ class BaseModel(object):
         return '{0} <{1}>'.format(self.__class__.__name__,
                                   ', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys()))

+    @classmethod
+    def _routing_key_from_values(cls, pk_values, protocol_version):
+        return cls._key_serializer(pk_values, protocol_version)
+
     @classmethod
     def _discover_polymorphic_submodels(cls):
         if not cls._is_polymorphic_base:
@@ -867,6 +871,11 @@ class ModelMetaClass(type):
         partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key)
         clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key)

+        key_cols = [c for c in partition_keys.values()]
+        partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols)
+        key_cql_types = [c.cql_type for c in key_cols]
+        key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)])
+
         # setup partition key shortcut
         if len(partition_keys) == 0:
             if not is_abstract:
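For illustration only (not part of the commit): a rough sketch of what the key serializer built above produces. The column names and CQL types below are hypothetical stand-ins for a model whose partition key is (user_id int, bucket text).

# Illustrative sketch -- the key names and types here are hypothetical,
# not taken from this commit.
from cassandra.cqltypes import Int32Type, UTF8Type

# The metaclass records the CQL type of each partition key column in
# declaration order; for a (user_id int, bucket text) key that would be:
key_cql_types = [Int32Type, UTF8Type]

def key_serializer(parts, proto_version):
    # Mirrors the lambda above: serialize each partition key value with its
    # CQL type, yielding one binary component per key column.
    return [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)]

# Two binary blobs, which the driver can use as a statement's routing key.
print(key_serializer([42, 'bucket-1'], 4))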
@@ -910,6 +919,8 @@ class ModelMetaClass(type):
         attrs['_dynamic_columns'] = {}

         attrs['_partition_keys'] = partition_keys
+        attrs['_partition_key_index'] = partition_key_index
+        attrs['_key_serializer'] = key_serializer
         attrs['_clustering_keys'] = clustering_keys
         attrs['_has_counter'] = len(counter_columns) > 0
@@ -84,6 +84,8 @@ class NamedTable(object):

     __partition_keys = None

+    _partition_key_index = None
+
     class DoesNotExist(_DoesNotExist):
         pass
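Note on the hunk above: NamedTable only declares _partition_key_index = None, and the new _execute_statement helper (added further down) checks model._partition_key_index before computing a routing key, so statements issued through named tables simply skip the routing-key logic; only declared models get one.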
@@ -19,6 +19,7 @@ import time
 import six
 from warnings import warn

+from cassandra.query import SimpleStatement
 from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin
 from cassandra.cqlengine import connection
 from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue
@@ -310,11 +311,11 @@ class AbstractQuerySet(object):
     def column_family_name(self):
         return self.model.column_family_name()

-    def _execute(self, q):
+    def _execute(self, statement):
         if self._batch:
-            return self._batch.add_query(q)
+            return self._batch.add_query(statement)
         else:
-            result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
+            result = _execute_statement(self.model, statement, self._consistency, None, self._timeout)
             if self._if_not_exists or self._if_exists or self._conditional:
                 check_applied(result)
             return result
@@ -1205,14 +1206,14 @@ class DMLQuery(object):
         self._conditional = conditional
         self._timeout = timeout

-    def _execute(self, q):
+    def _execute(self, statement):
         if self._batch:
-            return self._batch.add_query(q)
+            return self._batch.add_query(statement)
         else:
-            tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
+            results = _execute_statement(self.model, statement, self._consistency, None, self._timeout)
             if self._if_not_exists or self._if_exists or self._conditional:
-                check_applied(tmp)
-            return tmp
+                check_applied(results)
+            return results

     def batch(self, batch_obj):
         if batch_obj is not None and not isinstance(batch_obj, BatchQuery):
@@ -1338,3 +1339,14 @@ class DMLQuery(object):
                 continue
             ds.add_where(col, EqualsOperator(), val)
         self._execute(ds)
+
+
+def _execute_statement(model, statement, consistency_level, fetch_size, timeout):
+    params = statement.get_context()
+    s = SimpleStatement(str(statement), consistency_level=consistency_level, fetch_size=fetch_size)
+    if model._partition_key_index:
+        key_values = statement.partition_key_values(model._partition_key_index)
+        if not any(v is None for v in key_values):
+            parts = model._routing_key_from_values(key_values, connection.get_cluster().protocol_version)
+            s.routing_key = parts
+    return connection.execute(s, params, timeout=timeout)
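An illustrative end-to-end sketch of the effect (not part of the commit): the keyspace, model and connection setup below are hypothetical, and a real run needs a reachable cluster plus a synced table. When the full partition key is bound by equality, the mapper query goes out as a SimpleStatement with routing_key set, so TokenAwarePolicy can pick a replica holding the row.

# Hypothetical usage sketch -- names are illustrative and a running Cassandra
# cluster (and sync_table) is assumed.
from cassandra.cluster import Cluster
from cassandra.policies import TokenAwarePolicy, DCAwareRoundRobinPolicy
from cassandra.cqlengine import columns, connection
from cassandra.cqlengine.models import Model

class User(Model):
    __keyspace__ = 'ks'
    user_id = columns.Integer(partition_key=True)
    name = columns.Text()

cluster = Cluster(load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()))
connection.set_session(cluster.connect())

# The WHERE clause binds the whole partition key by equality, so the statement
# built by _execute_statement carries a routing key for token-aware routing.
user = User.objects(user_id=42).first()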
@@ -21,7 +21,7 @@ from cassandra.query import FETCH_SIZE_UNSET
 from cassandra.cqlengine import columns
 from cassandra.cqlengine import UnicodeMixin
 from cassandra.cqlengine.functions import QueryValue
-from cassandra.cqlengine.operators import BaseWhereOperator, InOperator
+from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator


 class StatementException(Exception):
@@ -505,7 +505,7 @@ class BaseCQLStatement(UnicodeMixin):

     def partition_key_values(self, field_index_map):
         parts = [None] * len(field_index_map)
-        self._update_part_key_values(field_index_map, self.where_clauses, parts)
+        self._update_part_key_values(field_index_map, (w for w in self.where_clauses if isinstance(w, EqualsOperator)), parts)
         return parts

     def add_where(self, column, operator, value, quote_field=True):
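Note on the two hunks above: partition_key_values now only draws values from equality where clauses, and _execute_statement only attaches a routing key when none of the resolved partition key values is None. A query that restricts part of the partition key with IN or a range therefore executes without a routing key, exactly as before this change.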
@@ -61,7 +61,6 @@ class BaseColumnIOTest(BaseCassEngTestCase):

         # create a table with the given column
         class IOTestModel(Model):
-            table_name = cls.column.db_type + "_io_test_model_{0}".format(uuid4().hex[:8])
             pkey = cls.column(primary_key=True)
             data = cls.column()
@@ -22,7 +22,6 @@ from datetime import datetime
 import time
 from uuid import uuid1, uuid4
-import uuid
 import sys

 from cassandra.cluster import Session
 from cassandra import InvalidRequest
@@ -457,8 +456,8 @@ def test_non_quality_filtering():
     NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now())
     NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now())

-    qA = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering()
-    num = qA.count()
+    qa = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering()
+    num = qa.count()
     assert num == 1, num
@@ -473,7 +472,7 @@ class TestQuerySetDistinct(BaseQuerySetUsage):
         self.assertEqual(len(q), 3)

     def test_distinct_with_filter(self):
-        q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1,2])
+        q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2])
         self.assertEqual(len(q), 2)

     def test_distinct_with_non_partition(self):
@@ -510,19 +509,19 @@ class TestQuerySetOrdering(BaseQuerySetUsage):
     def test_ordering_by_non_second_primary_keys_fail(self):
         # kwarg filtering
         with self.assertRaises(query.QueryException):
-            q = TestModel.objects(test_id=0).order_by('test_id')
+            TestModel.objects(test_id=0).order_by('test_id')

         # kwarg filtering
         with self.assertRaises(query.QueryException):
-            q = TestModel.objects(TestModel.test_id == 0).order_by('test_id')
+            TestModel.objects(TestModel.test_id == 0).order_by('test_id')

     def test_ordering_by_non_primary_keys_fails(self):
         with self.assertRaises(query.QueryException):
-            q = TestModel.objects(test_id=0).order_by('description')
+            TestModel.objects(test_id=0).order_by('description')

     def test_ordering_on_indexed_columns_fails(self):
         with self.assertRaises(query.QueryException):
-            q = IndexedTestModel.objects(test_id=0).order_by('attempt_id')
+            IndexedTestModel.objects(test_id=0).order_by('attempt_id')

     def test_ordering_on_multiple_clustering_columns(self):
         TestMultiClusteringModel.create(one=1, two=1, three=4)
@@ -673,7 +672,7 @@ class TestQuerySetDelete(BaseQuerySetUsage):
         TestMultiClusteringModel.objects(one=1, two__gt=3, two__lt=5).delete()
         self.assertEqual(5, len(TestMultiClusteringModel.objects.all()))

-        TestMultiClusteringModel.objects(one=1, two__in=[8,9]).delete()
+        TestMultiClusteringModel.objects(one=1, two__in=[8, 9]).delete()
         self.assertEqual(3, len(TestMultiClusteringModel.objects.all()))

         TestMultiClusteringModel.objects(one__in=[1], two__gte=0).delete()
@@ -878,7 +877,7 @@ class TestValuesList(BaseQuerySetUsage):
 class TestObjectsProperty(BaseQuerySetUsage):
     def test_objects_property_returns_fresh_queryset(self):
         assert TestModel.objects._result_cache is None
-        len(TestModel.objects) # evaluate queryset
+        len(TestModel.objects)  # evaluate queryset
         assert TestModel.objects._result_cache is None
@@ -74,7 +74,7 @@ class RoutingTests(unittest.TestCase):
         select = s.prepare("SELECT token(%s) FROM %s WHERE %s" %
                            (primary_key, table_name, where_clause))

-        return (insert, select)
+        return insert, select

     def test_singular_key(self):
         # string