cqle: attach routing key to mapper statements

PYTHON-535
This commit is contained in:
Adam Holmberg
2016-04-04 12:39:52 -05:00
parent c4c0c839f1
commit 7a1fde1a69
8 changed files with 58 additions and 23 deletions

View File

@@ -19,7 +19,7 @@ import six
from uuid import UUID as _UUID
from cassandra import util
from cassandra.cqltypes import SimpleDateType
from cassandra.cqltypes import SimpleDateType, _cqltypes, UserType
from cassandra.cqlengine import ValidationError
from cassandra.cqlengine.functions import get_total_seconds
@@ -256,6 +256,10 @@ class Column(object):
def sub_types(self):
return []
@property
def cql_type(self):
return _cqltypes[self.db_type]
class Blob(Column):
"""
@@ -666,6 +670,10 @@ class BaseCollectionColumn(Column):
def sub_types(self):
return self.types
@property
def cql_type(self):
return _cqltypes[self.__class__.__name__.lower()].apply_parameters([c.cql_type for c in self.types])
class Tuple(BaseCollectionColumn):
"""
@@ -877,6 +885,10 @@ class UserDefinedType(Column):
def sub_types(self):
return list(self.user_type._fields.values())
@property
def cql_type(self):
return UserType.apply_parameters([c.cql_type for c in self.user_type._fields.values()])
def resolve_udts(col_def, out_list):
for col in col_def.sub_types:

View File

@@ -379,6 +379,10 @@ class BaseModel(object):
return '{0} <{1}>'.format(self.__class__.__name__,
', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys()))
@classmethod
def _routing_key_from_values(cls, pk_values, protocol_version):
return cls._key_serializer(pk_values, protocol_version)
@classmethod
def _discover_polymorphic_submodels(cls):
if not cls._is_polymorphic_base:
@@ -867,6 +871,11 @@ class ModelMetaClass(type):
partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key)
clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key)
key_cols = [c for c in partition_keys.values()]
partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols)
key_cql_types = [c.cql_type for c in key_cols]
key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)])
# setup partition key shortcut
if len(partition_keys) == 0:
if not is_abstract:
@@ -910,6 +919,8 @@ class ModelMetaClass(type):
attrs['_dynamic_columns'] = {}
attrs['_partition_keys'] = partition_keys
attrs['_partition_key_index'] = partition_key_index
attrs['_key_serializer'] = key_serializer
attrs['_clustering_keys'] = clustering_keys
attrs['_has_counter'] = len(counter_columns) > 0

View File

@@ -84,6 +84,8 @@ class NamedTable(object):
__partition_keys = None
_partition_key_index = None
class DoesNotExist(_DoesNotExist):
pass

View File

@@ -19,6 +19,7 @@ import time
import six
from warnings import warn
from cassandra.query import SimpleStatement
from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin
from cassandra.cqlengine import connection
from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue
@@ -310,11 +311,11 @@ class AbstractQuerySet(object):
def column_family_name(self):
return self.model.column_family_name()
def _execute(self, q):
def _execute(self, statement):
if self._batch:
return self._batch.add_query(q)
return self._batch.add_query(statement)
else:
result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
result = _execute_statement(self.model, statement, self._consistency, None, self._timeout)
if self._if_not_exists or self._if_exists or self._conditional:
check_applied(result)
return result
@@ -1205,14 +1206,14 @@ class DMLQuery(object):
self._conditional = conditional
self._timeout = timeout
def _execute(self, q):
def _execute(self, statement):
if self._batch:
return self._batch.add_query(q)
return self._batch.add_query(statement)
else:
tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
results = _execute_statement(self.model, statement, self._consistency, None, self._timeout)
if self._if_not_exists or self._if_exists or self._conditional:
check_applied(tmp)
return tmp
check_applied(results)
return results
def batch(self, batch_obj):
if batch_obj is not None and not isinstance(batch_obj, BatchQuery):
@@ -1338,3 +1339,14 @@ class DMLQuery(object):
continue
ds.add_where(col, EqualsOperator(), val)
self._execute(ds)
def _execute_statement(model, statement, consistency_level, fetch_size, timeout):
params = statement.get_context()
s = SimpleStatement(str(statement), consistency_level=consistency_level, fetch_size=fetch_size)
if model._partition_key_index:
key_values = statement.partition_key_values(model._partition_key_index)
if not any(v is None for v in key_values):
parts = model._routing_key_from_values(key_values, connection.get_cluster().protocol_version)
s.routing_key = parts
return connection.execute(s, params, timeout=timeout)

View File

@@ -21,7 +21,7 @@ from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine import columns
from cassandra.cqlengine import UnicodeMixin
from cassandra.cqlengine.functions import QueryValue
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator
class StatementException(Exception):
@@ -505,7 +505,7 @@ class BaseCQLStatement(UnicodeMixin):
def partition_key_values(self, field_index_map):
parts = [None] * len(field_index_map)
self._update_part_key_values(field_index_map, self.where_clauses, parts)
self._update_part_key_values(field_index_map, (w for w in self.where_clauses if isinstance(w, EqualsOperator)), parts)
return parts
def add_where(self, column, operator, value, quote_field=True):

View File

@@ -61,7 +61,6 @@ class BaseColumnIOTest(BaseCassEngTestCase):
# create a table with the given column
class IOTestModel(Model):
table_name = cls.column.db_type + "_io_test_model_{0}".format(uuid4().hex[:8])
pkey = cls.column(primary_key=True)
data = cls.column()

View File

@@ -22,7 +22,6 @@ from datetime import datetime
import time
from uuid import uuid1, uuid4
import uuid
import sys
from cassandra.cluster import Session
from cassandra import InvalidRequest
@@ -457,8 +456,8 @@ def test_non_quality_filtering():
NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now())
NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now())
qA = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering()
num = qA.count()
qa = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering()
num = qa.count()
assert num == 1, num
@@ -473,7 +472,7 @@ class TestQuerySetDistinct(BaseQuerySetUsage):
self.assertEqual(len(q), 3)
def test_distinct_with_filter(self):
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1,2])
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2])
self.assertEqual(len(q), 2)
def test_distinct_with_non_partition(self):
@@ -510,19 +509,19 @@ class TestQuerySetOrdering(BaseQuerySetUsage):
def test_ordering_by_non_second_primary_keys_fail(self):
# kwarg filtering
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id=0).order_by('test_id')
TestModel.objects(test_id=0).order_by('test_id')
# kwarg filtering
with self.assertRaises(query.QueryException):
q = TestModel.objects(TestModel.test_id == 0).order_by('test_id')
TestModel.objects(TestModel.test_id == 0).order_by('test_id')
def test_ordering_by_non_primary_keys_fails(self):
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id=0).order_by('description')
TestModel.objects(test_id=0).order_by('description')
def test_ordering_on_indexed_columns_fails(self):
with self.assertRaises(query.QueryException):
q = IndexedTestModel.objects(test_id=0).order_by('attempt_id')
IndexedTestModel.objects(test_id=0).order_by('attempt_id')
def test_ordering_on_multiple_clustering_columns(self):
TestMultiClusteringModel.create(one=1, two=1, three=4)
@@ -673,7 +672,7 @@ class TestQuerySetDelete(BaseQuerySetUsage):
TestMultiClusteringModel.objects(one=1, two__gt=3, two__lt=5).delete()
self.assertEqual(5, len(TestMultiClusteringModel.objects.all()))
TestMultiClusteringModel.objects(one=1, two__in=[8,9]).delete()
TestMultiClusteringModel.objects(one=1, two__in=[8, 9]).delete()
self.assertEqual(3, len(TestMultiClusteringModel.objects.all()))
TestMultiClusteringModel.objects(one__in=[1], two__gte=0).delete()
@@ -878,7 +877,7 @@ class TestValuesList(BaseQuerySetUsage):
class TestObjectsProperty(BaseQuerySetUsage):
def test_objects_property_returns_fresh_queryset(self):
assert TestModel.objects._result_cache is None
len(TestModel.objects) # evaluate queryset
len(TestModel.objects) # evaluate queryset
assert TestModel.objects._result_cache is None

View File

@@ -74,7 +74,7 @@ class RoutingTests(unittest.TestCase):
select = s.prepare("SELECT token(%s) FROM %s WHERE %s" %
(primary_key, table_name, where_clause))
return (insert, select)
return insert, select
def test_singular_key(self):
# string