Merge pull request #537 from datastax/535
PYTHON-535 - set routing key in mapper queries
This commit is contained in:
@@ -19,7 +19,7 @@ import six
|
||||
from uuid import UUID as _UUID
|
||||
|
||||
from cassandra import util
|
||||
from cassandra.cqltypes import SimpleDateType
|
||||
from cassandra.cqltypes import SimpleDateType, _cqltypes, UserType
|
||||
from cassandra.cqlengine import ValidationError
|
||||
from cassandra.cqlengine.functions import get_total_seconds
|
||||
|
||||
@@ -159,6 +159,7 @@ class Column(object):
|
||||
|
||||
# the column name in the model definition
|
||||
self.column_name = None
|
||||
self._partition_key_index = None
|
||||
self.static = static
|
||||
|
||||
self.value = None
|
||||
@@ -255,6 +256,10 @@ class Column(object):
|
||||
def sub_types(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def cql_type(self):
|
||||
return _cqltypes[self.db_type]
|
||||
|
||||
|
||||
class Blob(Column):
|
||||
"""
|
||||
@@ -665,6 +670,10 @@ class BaseCollectionColumn(Column):
|
||||
def sub_types(self):
|
||||
return self.types
|
||||
|
||||
@property
|
||||
def cql_type(self):
|
||||
return _cqltypes[self.__class__.__name__.lower()].apply_parameters([c.cql_type for c in self.types])
|
||||
|
||||
|
||||
class Tuple(BaseCollectionColumn):
|
||||
"""
|
||||
@@ -876,6 +885,12 @@ class UserDefinedType(Column):
|
||||
def sub_types(self):
|
||||
return list(self.user_type._fields.values())
|
||||
|
||||
@property
|
||||
def cql_type(self):
|
||||
return UserType.make_udt_class(keyspace='', udt_name=self.user_type.type_name(),
|
||||
field_names=[c.db_field_name for c in self.user_type._fields.values()],
|
||||
field_types=[c.cql_type for c in self.user_type._fields.values()])
|
||||
|
||||
|
||||
def resolve_udts(col_def, out_list):
|
||||
for col in col_def.sub_types:
|
||||
|
||||
@@ -157,19 +157,16 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET):
|
||||
if not session:
|
||||
raise CQLEngineException("It is required to setup() cqlengine before executing queries")
|
||||
|
||||
if isinstance(query, Statement):
|
||||
pass
|
||||
|
||||
if isinstance(query, SimpleStatement):
|
||||
pass #
|
||||
elif isinstance(query, BaseCQLStatement):
|
||||
params = query.get_context()
|
||||
query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
|
||||
|
||||
elif isinstance(query, six.string_types):
|
||||
query = SimpleStatement(query, consistency_level=consistency_level)
|
||||
|
||||
log.debug(query.query_string)
|
||||
|
||||
params = params or {}
|
||||
result = session.execute(query, params, timeout=timeout)
|
||||
|
||||
return result
|
||||
|
||||
@@ -335,6 +335,8 @@ class BaseModel(object):
|
||||
|
||||
__options__ = None
|
||||
|
||||
__compute_routing_key__ = True
|
||||
|
||||
# the queryset class used for this class
|
||||
__queryset__ = query.ModelQuerySet
|
||||
__dmlquery__ = query.DMLQuery
|
||||
@@ -379,6 +381,10 @@ class BaseModel(object):
|
||||
return '{0} <{1}>'.format(self.__class__.__name__,
|
||||
', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys()))
|
||||
|
||||
@classmethod
def _routing_key_from_values(cls, pk_values, protocol_version):
    """Serialize partition-key values into binary routing-key parts.

    Delegates to the ``_key_serializer`` generated for this model by the
    metaclass (a no-op serializer when ``__compute_routing_key__`` is False).
    """
    serializer = cls._key_serializer
    return serializer(pk_values, protocol_version)
|
||||
|
||||
@classmethod
|
||||
def _discover_polymorphic_submodels(cls):
|
||||
if not cls._is_polymorphic_base:
|
||||
@@ -843,6 +849,7 @@ class ModelMetaClass(type):
|
||||
|
||||
has_partition_keys = any(v.partition_key for (k, v) in column_definitions)
|
||||
|
||||
partition_key_index = 0
|
||||
# transform column definitions
|
||||
for k, v in column_definitions:
|
||||
# don't allow a column with the same name as a built-in attribute or method
|
||||
@@ -858,11 +865,23 @@ class ModelMetaClass(type):
|
||||
if not has_partition_keys and v.primary_key:
|
||||
v.partition_key = True
|
||||
has_partition_keys = True
|
||||
if v.partition_key:
|
||||
v._partition_key_index = partition_key_index
|
||||
partition_key_index += 1
|
||||
_transform_column(k, v)
|
||||
|
||||
partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key)
|
||||
clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key)
|
||||
|
||||
if attrs.get('__compute_routing_key__', True):
|
||||
key_cols = [c for c in partition_keys.values()]
|
||||
partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols)
|
||||
key_cql_types = [c.cql_type for c in key_cols]
|
||||
key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)])
|
||||
else:
|
||||
partition_key_index = {}
|
||||
key_serializer = staticmethod(lambda parts, proto_version: None)
|
||||
|
||||
# setup partition key shortcut
|
||||
if len(partition_keys) == 0:
|
||||
if not is_abstract:
|
||||
@@ -906,6 +925,8 @@ class ModelMetaClass(type):
|
||||
attrs['_dynamic_columns'] = {}
|
||||
|
||||
attrs['_partition_keys'] = partition_keys
|
||||
attrs['_partition_key_index'] = partition_key_index
|
||||
attrs['_key_serializer'] = key_serializer
|
||||
attrs['_clustering_keys'] = clustering_keys
|
||||
attrs['_has_counter'] = len(counter_columns) > 0
|
||||
|
||||
@@ -983,3 +1004,8 @@ class Model(BaseModel):
|
||||
"""
|
||||
*Optional* Specifies a value for the discriminator column when using model inheritance.
|
||||
"""
|
||||
|
||||
__compute_routing_key__ = True
|
||||
"""
|
||||
*Optional* Setting False disables computing the routing key for TokenAwareRouting
|
||||
"""
|
||||
|
||||
@@ -84,6 +84,8 @@ class NamedTable(object):
|
||||
|
||||
__partition_keys = None
|
||||
|
||||
_partition_key_index = None
|
||||
|
||||
class DoesNotExist(_DoesNotExist):
|
||||
pass
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import time
|
||||
import six
|
||||
from warnings import warn
|
||||
|
||||
from cassandra.query import SimpleStatement
|
||||
from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin
|
||||
from cassandra.cqlengine import connection
|
||||
from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue
|
||||
@@ -26,10 +27,8 @@ from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterTh
|
||||
GreaterThanOrEqualOperator, LessThanOperator,
|
||||
LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator)
|
||||
from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement,
|
||||
UpdateStatement, AssignmentClause, InsertStatement,
|
||||
BaseCQLStatement, MapUpdateClause, MapDeleteClause,
|
||||
ListUpdateClause, SetUpdateClause, CounterUpdateClause,
|
||||
ConditionalClause)
|
||||
UpdateStatement, InsertStatement,
|
||||
BaseCQLStatement, MapDeleteClause, ConditionalClause)
|
||||
|
||||
|
||||
class QueryException(CQLEngineException):
|
||||
@@ -39,6 +38,7 @@ class QueryException(CQLEngineException):
|
||||
class IfNotExistsWithCounterColumn(CQLEngineException):
|
||||
pass
|
||||
|
||||
|
||||
class IfExistsWithCounterColumn(CQLEngineException):
|
||||
pass
|
||||
|
||||
@@ -311,11 +311,11 @@ class AbstractQuerySet(object):
|
||||
def column_family_name(self):
|
||||
return self.model.column_family_name()
|
||||
|
||||
def _execute(self, statement):
    """Execute ``statement`` for this queryset.

    If a batch is attached, the statement is deferred into the batch.
    Otherwise it is run immediately through ``_execute_statement`` (which
    attaches a routing key when possible) with this queryset's consistency
    and timeout; LWT results are verified when conditions were used.

    Note: the diff rendering interleaved the superseded ``q``-based
    ``connection.execute`` path with the new ``_execute_statement`` path;
    only the coherent new version is kept here.
    """
    if self._batch:
        return self._batch.add_query(statement)
    else:
        result = _execute_statement(self.model, statement, self._consistency, self._timeout)
        if self._if_not_exists or self._if_exists or self._conditional:
            check_applied(result)
        return result
|
||||
@@ -1209,14 +1209,14 @@ class DMLQuery(object):
|
||||
self._conditional = conditional
|
||||
self._timeout = timeout
|
||||
|
||||
def _execute(self, statement):
    """Execute ``statement`` for this DML query.

    Batched queries are accumulated into the attached batch; otherwise the
    statement runs immediately via ``_execute_statement`` (routing-key aware)
    and, when IF EXISTS / IF NOT EXISTS / conditionals were used, the LWT
    ``[applied]`` result is checked before returning.

    Note: the diff rendering interleaved the superseded ``q``/``tmp`` lines
    with the new ``statement``/``results`` lines; only the coherent new
    version is kept here.
    """
    if self._batch:
        return self._batch.add_query(statement)
    else:
        results = _execute_statement(self.model, statement, self._consistency, self._timeout)
        if self._if_not_exists or self._if_exists or self._conditional:
            check_applied(results)
        return results
|
||||
|
||||
def batch(self, batch_obj):
|
||||
if batch_obj is not None and not isinstance(batch_obj, BatchQuery):
|
||||
@@ -1243,11 +1243,7 @@ class DMLQuery(object):
|
||||
|
||||
if deleted_fields:
|
||||
for name, col in self.model._primary_keys.items():
|
||||
ds.add_where_clause(WhereClause(
|
||||
col.db_field_name,
|
||||
EqualsOperator(),
|
||||
col.to_database(getattr(self.instance, name))
|
||||
))
|
||||
ds.add_where(col, EqualsOperator(), getattr(self.instance, name))
|
||||
self._execute(ds)
|
||||
|
||||
def update(self):
|
||||
@@ -1289,11 +1285,7 @@ class DMLQuery(object):
|
||||
# only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error
|
||||
if (null_clustering_key or static_changed_only) and (not col.partition_key):
|
||||
continue
|
||||
statement.add_where_clause(WhereClause(
|
||||
col.db_field_name,
|
||||
EqualsOperator(),
|
||||
col.to_database(getattr(self.instance, name))
|
||||
))
|
||||
statement.add_where(col, EqualsOperator(), getattr(self.instance, name))
|
||||
self._execute(statement)
|
||||
|
||||
if not null_clustering_key:
|
||||
@@ -1328,10 +1320,7 @@ class DMLQuery(object):
|
||||
if self.instance._values[name].changed:
|
||||
nulled_fields.add(col.db_field_name)
|
||||
continue
|
||||
insert.add_assignment_clause(AssignmentClause(
|
||||
col.db_field_name,
|
||||
col.to_database(getattr(self.instance, name, None))
|
||||
))
|
||||
insert.add_assignment(col, getattr(self.instance, name, None))
|
||||
|
||||
# skip query execution if it's empty
|
||||
# caused by pointless update queries
|
||||
@@ -1348,12 +1337,20 @@ class DMLQuery(object):
|
||||
|
||||
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
for name, col in self.model._primary_keys.items():
    val = getattr(self.instance, name)
    # Skip null clustering keys so a partial primary key still deletes the row
    # range. Bug fix: the diff wrote `col.parition_key` (typo), which would
    # raise AttributeError on every Column at runtime.
    if val is None and not col.partition_key:
        continue
    ds.add_where(col, EqualsOperator(), val)
self._execute(ds)
|
||||
|
||||
|
||||
def _execute_statement(model, statement, consistency_level, timeout):
    """Run a cqlengine statement, setting a routing key when it can be derived.

    Builds a ``SimpleStatement`` from the rendered CQL, then — when the model
    has a partition-key index and every key component is bound in the
    statement — computes and attaches the routing key and keyspace so
    token-aware load balancing can route the request.
    """
    params = statement.get_context()
    simple = SimpleStatement(str(statement),
                             consistency_level=consistency_level,
                             fetch_size=statement.fetch_size)
    if model._partition_key_index:
        key_values = statement.partition_key_values(model._partition_key_index)
        if not any(v is None for v in key_values):
            simple.routing_key = model._routing_key_from_values(
                key_values, connection.get_cluster().protocol_version)
            simple.keyspace = model._get_keyspace()
    return connection.execute(simple, params, timeout=timeout)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from itertools import ifilter
|
||||
import time
|
||||
import six
|
||||
|
||||
@@ -20,7 +21,7 @@ from cassandra.query import FETCH_SIZE_UNSET
|
||||
from cassandra.cqlengine import columns
|
||||
from cassandra.cqlengine import UnicodeMixin
|
||||
from cassandra.cqlengine.functions import QueryValue
|
||||
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator
|
||||
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator
|
||||
|
||||
|
||||
class StatementException(Exception):
|
||||
@@ -481,10 +482,9 @@ class MapDeleteClause(BaseDeleteClause):
|
||||
class BaseCQLStatement(UnicodeMixin):
|
||||
""" The base cql statement class """
|
||||
|
||||
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
|
||||
def __init__(self, table, timestamp=None, where=None, fetch_size=None, conditionals=None):
|
||||
super(BaseCQLStatement, self).__init__()
|
||||
self.table = table
|
||||
self.consistency = consistency
|
||||
self.context_id = 0
|
||||
self.context_counter = self.context_id
|
||||
self.timestamp = timestamp
|
||||
@@ -492,20 +492,27 @@ class BaseCQLStatement(UnicodeMixin):
|
||||
|
||||
self.where_clauses = []
|
||||
for clause in where or []:
|
||||
self.add_where_clause(clause)
|
||||
self._add_where_clause(clause)
|
||||
|
||||
self.conditionals = []
|
||||
for conditional in conditionals or []:
|
||||
self.add_conditional_clause(conditional)
|
||||
|
||||
def add_where_clause(self, clause):
|
||||
"""
|
||||
adds a where clause to this statement
|
||||
:param clause: the clause to add
|
||||
:type clause: WhereClause
|
||||
"""
|
||||
if not isinstance(clause, WhereClause):
|
||||
raise StatementException("only instances of WhereClause can be added to statements")
|
||||
def _update_part_key_values(self, field_index_map, clauses, parts):
|
||||
for clause in ifilter(lambda c: c.field in field_index_map, clauses):
|
||||
parts[field_index_map[clause.field]] = clause.value
|
||||
|
||||
def partition_key_values(self, field_index_map):
    """Return partition-key values bound by equality WHERE clauses.

    Positions with no equality clause remain ``None`` (callers treat any
    ``None`` as "routing key cannot be computed").
    """
    parts = [None] * len(field_index_map)
    equality_clauses = (w for w in self.where_clauses
                        if w.operator.__class__ == EqualsOperator)
    self._update_part_key_values(field_index_map, equality_clauses, parts)
    return parts
|
||||
|
||||
def add_where(self, column, operator, value, quote_field=True):
    """Append a WHERE clause for ``column``.

    ``value`` is serialized through the column's ``to_database`` before the
    clause is built, so callers pass model-level values.
    """
    db_value = column.to_database(value)
    self._add_where_clause(
        WhereClause(column.db_field_name, operator, db_value, quote_field))
||||
|
||||
def _add_where_clause(self, clause):
|
||||
clause.set_context_id(self.context_counter)
|
||||
self.context_counter += clause.get_context_size()
|
||||
self.where_clauses.append(clause)
|
||||
@@ -581,7 +588,6 @@ class SelectStatement(BaseCQLStatement):
|
||||
table,
|
||||
fields=None,
|
||||
count=False,
|
||||
consistency=None,
|
||||
where=None,
|
||||
order_by=None,
|
||||
limit=None,
|
||||
@@ -595,7 +601,6 @@ class SelectStatement(BaseCQLStatement):
|
||||
"""
|
||||
super(SelectStatement, self).__init__(
|
||||
table,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
fetch_size=fetch_size
|
||||
)
|
||||
@@ -641,14 +646,12 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
def __init__(self,
|
||||
table,
|
||||
assignments=None,
|
||||
consistency=None,
|
||||
where=None,
|
||||
ttl=None,
|
||||
timestamp=None,
|
||||
conditionals=None):
|
||||
super(AssignmentStatement, self).__init__(
|
||||
table,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
conditionals=conditionals
|
||||
)
|
||||
@@ -658,7 +661,7 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
# add assignments
|
||||
self.assignments = []
|
||||
for assignment in assignments or []:
|
||||
self.add_assignment_clause(assignment)
|
||||
self._add_assignment_clause(assignment)
|
||||
|
||||
def update_context_id(self, i):
|
||||
super(AssignmentStatement, self).update_context_id(i)
|
||||
@@ -666,14 +669,17 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
assignment.set_context_id(self.context_counter)
|
||||
self.context_counter += assignment.get_context_size()
|
||||
|
||||
def add_assignment_clause(self, clause):
|
||||
"""
|
||||
adds an assignment clause to this statement
|
||||
:param clause: the clause to add
|
||||
:type clause: AssignmentClause
|
||||
"""
|
||||
if not isinstance(clause, AssignmentClause):
|
||||
raise StatementException("only instances of AssignmentClause can be added to statements")
|
||||
def partition_key_values(self, field_index_map):
    """Partition-key values from WHERE clauses plus assignment clauses.

    Extends the base-class result (equality WHERE clauses) with values set
    by this statement's assignments, e.g. INSERT column values.
    """
    values = super(AssignmentStatement, self).partition_key_values(field_index_map)
    self._update_part_key_values(field_index_map, self.assignments, values)
    return values
|
||||
|
||||
def add_assignment(self, column, value):
    """Append an assignment clause for ``column``.

    ``value`` is serialized through the column's ``to_database`` first, so
    callers pass model-level values.
    """
    db_value = column.to_database(value)
    self._add_assignment_clause(
        AssignmentClause(column.db_field_name, db_value))
|
||||
|
||||
def _add_assignment_clause(self, clause):
|
||||
clause.set_context_id(self.context_counter)
|
||||
self.context_counter += clause.get_context_size()
|
||||
self.assignments.append(clause)
|
||||
@@ -695,23 +701,18 @@ class InsertStatement(AssignmentStatement):
|
||||
def __init__(self,
|
||||
table,
|
||||
assignments=None,
|
||||
consistency=None,
|
||||
where=None,
|
||||
ttl=None,
|
||||
timestamp=None,
|
||||
if_not_exists=False):
|
||||
super(InsertStatement, self).__init__(table,
|
||||
assignments=assignments,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
ttl=ttl,
|
||||
timestamp=timestamp)
|
||||
|
||||
self.if_not_exists = if_not_exists
|
||||
|
||||
def add_where_clause(self, clause):
|
||||
raise StatementException("Cannot add where clauses to insert statements")
|
||||
|
||||
def __unicode__(self):
|
||||
qs = ['INSERT INTO {0}'.format(self.table)]
|
||||
|
||||
@@ -741,7 +742,6 @@ class UpdateStatement(AssignmentStatement):
|
||||
def __init__(self,
|
||||
table,
|
||||
assignments=None,
|
||||
consistency=None,
|
||||
where=None,
|
||||
ttl=None,
|
||||
timestamp=None,
|
||||
@@ -749,7 +749,6 @@ class UpdateStatement(AssignmentStatement):
|
||||
if_exists=False):
|
||||
super(UpdateStatement, self). __init__(table,
|
||||
assignments=assignments,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
ttl=ttl,
|
||||
timestamp=timestamp,
|
||||
@@ -809,16 +808,15 @@ class UpdateStatement(AssignmentStatement):
|
||||
else:
|
||||
clause = AssignmentClause(column.db_field_name, value)
|
||||
if clause.get_context_size(): # this is to exclude map removals from updates. Can go away if we drop support for C* < 1.2.4 and remove two-phase updates
|
||||
self.add_assignment_clause(clause)
|
||||
self._add_assignment_clause(clause)
|
||||
|
||||
|
||||
class DeleteStatement(BaseCQLStatement):
|
||||
""" a cql delete statement """
|
||||
|
||||
def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, conditionals=None, if_exists=False):
|
||||
def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=None, if_exists=False):
|
||||
super(DeleteStatement, self).__init__(
|
||||
table,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
timestamp=timestamp,
|
||||
conditionals=conditionals
|
||||
|
||||
@@ -77,6 +77,7 @@ def trim_if_startswith(s, prefix):
|
||||
|
||||
|
||||
_casstypes = {}
|
||||
_cqltypes = {}
|
||||
|
||||
|
||||
cql_type_scanner = re.Scanner((
|
||||
@@ -106,6 +107,8 @@ class CassandraTypeType(type):
|
||||
cls = type.__new__(metacls, name, bases, dct)
|
||||
if not name.startswith('_'):
|
||||
_casstypes[name] = cls
|
||||
if not cls.typename.startswith("'org"):
|
||||
_cqltypes[cls.typename] = cls
|
||||
return cls
|
||||
|
||||
|
||||
@@ -620,6 +623,11 @@ class SimpleDateType(_CassandraType):
|
||||
try:
|
||||
days = val.days_from_epoch
|
||||
except AttributeError:
|
||||
if isinstance(val, int):
|
||||
# the DB wants offset int values, but util.Date init takes days from epoch
|
||||
# here we assume int values are offset, as they would appear in CQL
|
||||
# short circuit to avoid subtracting just to add offset
|
||||
return uint32_pack(val)
|
||||
days = util.Date(val).days_from_epoch
|
||||
return uint32_pack(days + SimpleDateType.EPOCH_OFFSET_DAYS)
|
||||
|
||||
|
||||
@@ -233,13 +233,20 @@ class Statement(object):
|
||||
if custom_payload is not None:
|
||||
self.custom_payload = custom_payload
|
||||
|
||||
def _key_parts_packed(self, parts):
|
||||
for p in parts:
|
||||
l = len(p)
|
||||
yield struct.pack(">H%dsB" % l, l, p, 0)
|
||||
|
||||
def _get_routing_key(self):
|
||||
return self._routing_key
|
||||
|
||||
def _set_routing_key(self, key):
|
||||
if isinstance(key, (list, tuple)):
|
||||
self._routing_key = b"".join(struct.pack("HsB", len(component), component, 0)
|
||||
for component in key)
|
||||
if len(key) == 1:
|
||||
self._routing_key = key[0]
|
||||
else:
|
||||
self._routing_key = b"".join(self._key_parts_packed(key))
|
||||
else:
|
||||
self._routing_key = key
|
||||
|
||||
@@ -565,13 +572,7 @@ class BoundStatement(Statement):
|
||||
if len(routing_indexes) == 1:
|
||||
self._routing_key = self.values[routing_indexes[0]]
|
||||
else:
|
||||
components = []
|
||||
for statement_index in routing_indexes:
|
||||
val = self.values[statement_index]
|
||||
l = len(val)
|
||||
components.append(struct.pack(">H%dsB" % l, l, val, 0))
|
||||
|
||||
self._routing_key = b"".join(components)
|
||||
self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes))
|
||||
|
||||
return self._routing_key
|
||||
|
||||
|
||||
@@ -79,6 +79,8 @@ Model
|
||||
'tombstone_compaction_interval': '86400'},
|
||||
'gc_grace_seconds': '0'}
|
||||
|
||||
.. autoattribute:: __compute_routing_key__
|
||||
|
||||
|
||||
The base methods allow creating, storing, and querying modeled objects.
|
||||
|
||||
|
||||
@@ -61,7 +61,6 @@ class BaseColumnIOTest(BaseCassEngTestCase):
|
||||
|
||||
# create a table with the given column
|
||||
class IOTestModel(Model):
|
||||
table_name = cls.column.db_type + "_io_test_model_{0}".format(uuid4().hex[:8])
|
||||
pkey = cls.column(primary_key=True)
|
||||
data = cls.column()
|
||||
|
||||
|
||||
@@ -456,8 +456,8 @@ def test_non_quality_filtering():
|
||||
NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now())
|
||||
NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now())
|
||||
|
||||
qA = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering()
|
||||
num = qA.count()
|
||||
qa = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering()
|
||||
num = qa.count()
|
||||
assert num == 1, num
|
||||
|
||||
|
||||
@@ -472,7 +472,7 @@ class TestQuerySetDistinct(BaseQuerySetUsage):
|
||||
self.assertEqual(len(q), 3)
|
||||
|
||||
def test_distinct_with_filter(self):
|
||||
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1,2])
|
||||
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2])
|
||||
self.assertEqual(len(q), 2)
|
||||
|
||||
def test_distinct_with_non_partition(self):
|
||||
@@ -509,19 +509,19 @@ class TestQuerySetOrdering(BaseQuerySetUsage):
|
||||
def test_ordering_by_non_second_primary_keys_fail(self):
|
||||
# kwarg filtering
|
||||
with self.assertRaises(query.QueryException):
|
||||
q = TestModel.objects(test_id=0).order_by('test_id')
|
||||
TestModel.objects(test_id=0).order_by('test_id')
|
||||
|
||||
# kwarg filtering
|
||||
with self.assertRaises(query.QueryException):
|
||||
q = TestModel.objects(TestModel.test_id == 0).order_by('test_id')
|
||||
TestModel.objects(TestModel.test_id == 0).order_by('test_id')
|
||||
|
||||
def test_ordering_by_non_primary_keys_fails(self):
|
||||
with self.assertRaises(query.QueryException):
|
||||
q = TestModel.objects(test_id=0).order_by('description')
|
||||
TestModel.objects(test_id=0).order_by('description')
|
||||
|
||||
def test_ordering_on_indexed_columns_fails(self):
|
||||
with self.assertRaises(query.QueryException):
|
||||
q = IndexedTestModel.objects(test_id=0).order_by('attempt_id')
|
||||
IndexedTestModel.objects(test_id=0).order_by('attempt_id')
|
||||
|
||||
def test_ordering_on_multiple_clustering_columns(self):
|
||||
TestMultiClusteringModel.create(one=1, two=1, three=4)
|
||||
@@ -672,7 +672,7 @@ class TestQuerySetDelete(BaseQuerySetUsage):
|
||||
TestMultiClusteringModel.objects(one=1, two__gt=3, two__lt=5).delete()
|
||||
self.assertEqual(5, len(TestMultiClusteringModel.objects.all()))
|
||||
|
||||
TestMultiClusteringModel.objects(one=1, two__in=[8,9]).delete()
|
||||
TestMultiClusteringModel.objects(one=1, two__in=[8, 9]).delete()
|
||||
self.assertEqual(3, len(TestMultiClusteringModel.objects.all()))
|
||||
|
||||
TestMultiClusteringModel.objects(one__in=[1], two__gte=0).delete()
|
||||
@@ -877,7 +877,7 @@ class TestValuesList(BaseQuerySetUsage):
|
||||
class TestObjectsProperty(BaseQuerySetUsage):
|
||||
def test_objects_property_returns_fresh_queryset(self):
|
||||
assert TestModel.objects._result_cache is None
|
||||
len(TestModel.objects) # evaluate queryset
|
||||
len(TestModel.objects) # evaluate queryset
|
||||
assert TestModel.objects._result_cache is None
|
||||
|
||||
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# Copyright 2013-2016 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.cqlengine.statements import AssignmentStatement, StatementException
|
||||
|
||||
|
||||
class AssignmentStatementTest(unittest.TestCase):
|
||||
|
||||
def test_add_assignment_type_checking(self):
|
||||
""" tests that only assignment clauses can be added to queries """
|
||||
stmt = AssignmentStatement('table', [])
|
||||
with self.assertRaises(StatementException):
|
||||
stmt.add_assignment_clause('x=5')
|
||||
@@ -17,17 +17,11 @@ except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.query import FETCH_SIZE_UNSET
|
||||
from cassandra.cqlengine.statements import BaseCQLStatement, StatementException
|
||||
from cassandra.cqlengine.statements import BaseCQLStatement
|
||||
|
||||
|
||||
class BaseStatementTest(unittest.TestCase):
|
||||
|
||||
def test_where_clause_type_checking(self):
|
||||
""" tests that only assignment clauses can be added to queries """
|
||||
stmt = BaseCQLStatement('table', [])
|
||||
with self.assertRaises(StatementException):
|
||||
stmt.add_where_clause('x=5')
|
||||
|
||||
def test_fetch_size(self):
|
||||
""" tests that fetch_size is correctly set """
|
||||
stmt = BaseCQLStatement('table', None, fetch_size=1000)
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
from cassandra.cqlengine.columns import Column
|
||||
from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause
|
||||
from cassandra.cqlengine.operators import *
|
||||
import six
|
||||
@@ -45,13 +47,13 @@ class DeleteStatementTests(TestCase):
|
||||
|
||||
def test_where_clause_rendering(self):
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s', six.text_type(ds))
|
||||
|
||||
def test_context_update(self):
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4}))
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
|
||||
ds.update_context_id(7)
|
||||
self.assertEqual(six.text_type(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s')
|
||||
@@ -59,19 +61,19 @@ class DeleteStatementTests(TestCase):
|
||||
|
||||
def test_context(self):
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(ds.get_context(), {'0': 'b'})
|
||||
|
||||
def test_range_deletion_rendering(self):
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where_clause(WhereClause('created_at', GreaterThanOrEqualOperator(), '0'))
|
||||
ds.add_where_clause(WhereClause('created_at', LessThanOrEqualOperator(), '10'))
|
||||
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
ds.add_where(Column(db_field='created_at'), GreaterThanOrEqualOperator(), '0')
|
||||
ds.add_where(Column(db_field='created_at'), LessThanOrEqualOperator(), '10')
|
||||
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', six.text_type(ds))
|
||||
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where_clause(WhereClause('created_at', InOperator(), ['0', '10', '20']))
|
||||
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
ds.add_where(Column(db_field='created_at'), InOperator(), ['0', '10', '20'])
|
||||
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds))
|
||||
|
||||
def test_delete_conditional(self):
|
||||
|
||||
@@ -16,22 +16,18 @@ try:
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause
|
||||
|
||||
import six
|
||||
|
||||
from cassandra.cqlengine.columns import Column
|
||||
from cassandra.cqlengine.statements import InsertStatement
|
||||
|
||||
|
||||
class InsertStatementTests(unittest.TestCase):
|
||||
|
||||
def test_where_clause_failure(self):
|
||||
""" tests that where clauses cannot be added to Insert statements """
|
||||
ist = InsertStatement('table', None)
|
||||
with self.assertRaises(StatementException):
|
||||
ist.add_where_clause('s')
|
||||
|
||||
def test_statement(self):
|
||||
ist = InsertStatement('table', None)
|
||||
ist.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
ist.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
ist.add_assignment(Column(db_field='a'), 'b')
|
||||
ist.add_assignment(Column(db_field='c'), 'd')
|
||||
|
||||
self.assertEqual(
|
||||
six.text_type(ist),
|
||||
@@ -40,8 +36,8 @@ class InsertStatementTests(unittest.TestCase):
|
||||
|
||||
def test_context_update(self):
|
||||
ist = InsertStatement('table', None)
|
||||
ist.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
ist.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
ist.add_assignment(Column(db_field='a'), 'b')
|
||||
ist.add_assignment(Column(db_field='c'), 'd')
|
||||
|
||||
ist.update_context_id(4)
|
||||
self.assertEqual(
|
||||
@@ -53,6 +49,6 @@ class InsertStatementTests(unittest.TestCase):
|
||||
|
||||
def test_additional_rendering(self):
|
||||
ist = InsertStatement('table', ttl=60)
|
||||
ist.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
ist.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
ist.add_assignment(Column(db_field='a'), 'b')
|
||||
ist.add_assignment(Column(db_field='c'), 'd')
|
||||
self.assertIn('USING TTL 60', six.text_type(ist))
|
||||
|
||||
@@ -16,6 +16,7 @@ try:
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.cqlengine.columns import Column
|
||||
from cassandra.cqlengine.statements import SelectStatement, WhereClause
|
||||
from cassandra.cqlengine.operators import *
|
||||
import six
|
||||
@@ -46,19 +47,19 @@ class SelectStatementTests(unittest.TestCase):
|
||||
|
||||
def test_where_clause_rendering(self):
|
||||
ss = SelectStatement('table')
|
||||
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(six.text_type(ss), 'SELECT * FROM table WHERE "a" = %(0)s', six.text_type(ss))
|
||||
|
||||
def test_count(self):
|
||||
ss = SelectStatement('table', count=True, limit=10, order_by='d')
|
||||
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(six.text_type(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', six.text_type(ss))
|
||||
self.assertIn('LIMIT', six.text_type(ss))
|
||||
self.assertNotIn('ORDER', six.text_type(ss))
|
||||
|
||||
def test_distinct(self):
|
||||
ss = SelectStatement('table', distinct_fields=['field2'])
|
||||
ss.add_where_clause(WhereClause('field1', EqualsOperator(), 'b'))
|
||||
ss.add_where(Column(db_field='field1'), EqualsOperator(), 'b')
|
||||
self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', six.text_type(ss))
|
||||
|
||||
ss = SelectStatement('table', distinct_fields=['field1', 'field2'])
|
||||
@@ -69,13 +70,13 @@ class SelectStatementTests(unittest.TestCase):
|
||||
|
||||
def test_context(self):
|
||||
ss = SelectStatement('table')
|
||||
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(ss.get_context(), {'0': 'b'})
|
||||
|
||||
def test_context_id_update(self):
|
||||
""" tests that the right things happen the the context id """
|
||||
ss = SelectStatement('table')
|
||||
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(ss.get_context(), {'0': 'b'})
|
||||
self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s')
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ try:
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.cqlengine.columns import Column, Set, List, Text
|
||||
from cassandra.cqlengine.operators import *
|
||||
from cassandra.cqlengine.statements import (UpdateStatement, WhereClause,
|
||||
AssignmentClause, SetUpdateClause,
|
||||
@@ -33,54 +34,54 @@ class UpdateStatementTests(unittest.TestCase):
|
||||
|
||||
def test_rendering(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
us.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
|
||||
us.add_assignment(Column(db_field='a'), 'b')
|
||||
us.add_assignment(Column(db_field='c'), 'd')
|
||||
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', six.text_type(us))
|
||||
|
||||
def test_context(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
us.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
|
||||
us.add_assignment(Column(db_field='a'), 'b')
|
||||
us.add_assignment(Column(db_field='c'), 'd')
|
||||
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
|
||||
self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'})
|
||||
|
||||
def test_context_update(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
us.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
|
||||
us.add_assignment(Column(db_field='a'), 'b')
|
||||
us.add_assignment(Column(db_field='c'), 'd')
|
||||
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
|
||||
us.update_context_id(3)
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s')
|
||||
self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'})
|
||||
|
||||
def test_additional_rendering(self):
|
||||
us = UpdateStatement('table', ttl=60)
|
||||
us.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
|
||||
us.add_assignment(Column(db_field='a'), 'b')
|
||||
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
|
||||
self.assertIn('USING TTL 60', six.text_type(us))
|
||||
|
||||
def test_update_set_add(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(SetUpdateClause('a', set((1,)), operation='add'))
|
||||
us.add_update(Set(Text, db_field='a'), set((1,)), 'add')
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
|
||||
|
||||
def test_update_empty_set_add_does_not_assign(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(SetUpdateClause('a', set(), operation='add'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
|
||||
us.add_update(Set(Text, db_field='a'), set(), 'add')
|
||||
self.assertFalse(us.assignments)
|
||||
|
||||
def test_update_empty_set_removal_does_not_assign(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(SetUpdateClause('a', set(), operation='remove'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" - %(0)s')
|
||||
us.add_update(Set(Text, db_field='a'), set(), 'remove')
|
||||
self.assertFalse(us.assignments)
|
||||
|
||||
def test_update_list_prepend_with_empty_list(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(ListUpdateClause('a', [], operation='prepend'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s + "a"')
|
||||
us.add_update(List(Text, db_field='a'), [], 'prepend')
|
||||
self.assertFalse(us.assignments)
|
||||
|
||||
def test_update_list_append_with_empty_list(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(ListUpdateClause('a', [], operation='append'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
|
||||
us.add_update(List(Text, db_field='a'), [], 'append')
|
||||
self.assertFalse(us.assignments)
|
||||
|
||||
@@ -74,7 +74,7 @@ class RoutingTests(unittest.TestCase):
|
||||
select = s.prepare("SELECT token(%s) FROM %s WHERE %s" %
|
||||
(primary_key, table_name, where_clause))
|
||||
|
||||
return (insert, select)
|
||||
return insert, select
|
||||
|
||||
def test_singular_key(self):
|
||||
# string
|
||||
|
||||
Reference in New Issue
Block a user