Merge pull request #537 from datastax/535

PYTHON-535 - set routing key in mapper queries
This commit is contained in:
Adam Holmberg
2016-04-05 10:59:24 -05:00
18 changed files with 187 additions and 176 deletions

View File

@@ -19,7 +19,7 @@ import six
from uuid import UUID as _UUID
from cassandra import util
from cassandra.cqltypes import SimpleDateType
from cassandra.cqltypes import SimpleDateType, _cqltypes, UserType
from cassandra.cqlengine import ValidationError
from cassandra.cqlengine.functions import get_total_seconds
@@ -159,6 +159,7 @@ class Column(object):
# the column name in the model definition
self.column_name = None
self._partition_key_index = None
self.static = static
self.value = None
@@ -255,6 +256,10 @@ class Column(object):
def sub_types(self):
return []
@property
def cql_type(self):
return _cqltypes[self.db_type]
class Blob(Column):
"""
@@ -665,6 +670,10 @@ class BaseCollectionColumn(Column):
def sub_types(self):
return self.types
@property
def cql_type(self):
return _cqltypes[self.__class__.__name__.lower()].apply_parameters([c.cql_type for c in self.types])
class Tuple(BaseCollectionColumn):
"""
@@ -876,6 +885,12 @@ class UserDefinedType(Column):
def sub_types(self):
return list(self.user_type._fields.values())
@property
def cql_type(self):
return UserType.make_udt_class(keyspace='', udt_name=self.user_type.type_name(),
field_names=[c.db_field_name for c in self.user_type._fields.values()],
field_types=[c.cql_type for c in self.user_type._fields.values()])
def resolve_udts(col_def, out_list):
for col in col_def.sub_types:

View File

@@ -157,19 +157,16 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET):
if not session:
raise CQLEngineException("It is required to setup() cqlengine before executing queries")
if isinstance(query, Statement):
pass
if isinstance(query, SimpleStatement):
pass #
elif isinstance(query, BaseCQLStatement):
params = query.get_context()
query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
elif isinstance(query, six.string_types):
query = SimpleStatement(query, consistency_level=consistency_level)
log.debug(query.query_string)
params = params or {}
result = session.execute(query, params, timeout=timeout)
return result

View File

@@ -335,6 +335,8 @@ class BaseModel(object):
__options__ = None
__compute_routing_key__ = True
# the queryset class used for this class
__queryset__ = query.ModelQuerySet
__dmlquery__ = query.DMLQuery
@@ -379,6 +381,10 @@ class BaseModel(object):
return '{0} <{1}>'.format(self.__class__.__name__,
', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys()))
@classmethod
def _routing_key_from_values(cls, pk_values, protocol_version):
return cls._key_serializer(pk_values, protocol_version)
@classmethod
def _discover_polymorphic_submodels(cls):
if not cls._is_polymorphic_base:
@@ -843,6 +849,7 @@ class ModelMetaClass(type):
has_partition_keys = any(v.partition_key for (k, v) in column_definitions)
partition_key_index = 0
# transform column definitions
for k, v in column_definitions:
# don't allow a column with the same name as a built-in attribute or method
@@ -858,11 +865,23 @@ class ModelMetaClass(type):
if not has_partition_keys and v.primary_key:
v.partition_key = True
has_partition_keys = True
if v.partition_key:
v._partition_key_index = partition_key_index
partition_key_index += 1
_transform_column(k, v)
partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key)
clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key)
if attrs.get('__compute_routing_key__', True):
key_cols = [c for c in partition_keys.values()]
partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols)
key_cql_types = [c.cql_type for c in key_cols]
key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)])
else:
partition_key_index = {}
key_serializer = staticmethod(lambda parts, proto_version: None)
# setup partition key shortcut
if len(partition_keys) == 0:
if not is_abstract:
@@ -906,6 +925,8 @@ class ModelMetaClass(type):
attrs['_dynamic_columns'] = {}
attrs['_partition_keys'] = partition_keys
attrs['_partition_key_index'] = partition_key_index
attrs['_key_serializer'] = key_serializer
attrs['_clustering_keys'] = clustering_keys
attrs['_has_counter'] = len(counter_columns) > 0
@@ -983,3 +1004,8 @@ class Model(BaseModel):
"""
*Optional* Specifies a value for the discriminator column when using model inheritance.
"""
__compute_routing_key__ = True
"""
*Optional* Setting False disables computing the routing key for TokenAwareRouting
"""

View File

@@ -84,6 +84,8 @@ class NamedTable(object):
__partition_keys = None
_partition_key_index = None
class DoesNotExist(_DoesNotExist):
pass

View File

@@ -19,6 +19,7 @@ import time
import six
from warnings import warn
from cassandra.query import SimpleStatement
from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin
from cassandra.cqlengine import connection
from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue
@@ -26,10 +27,8 @@ from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterTh
GreaterThanOrEqualOperator, LessThanOperator,
LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator)
from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement,
UpdateStatement, AssignmentClause, InsertStatement,
BaseCQLStatement, MapUpdateClause, MapDeleteClause,
ListUpdateClause, SetUpdateClause, CounterUpdateClause,
ConditionalClause)
UpdateStatement, InsertStatement,
BaseCQLStatement, MapDeleteClause, ConditionalClause)
class QueryException(CQLEngineException):
@@ -39,6 +38,7 @@ class QueryException(CQLEngineException):
class IfNotExistsWithCounterColumn(CQLEngineException):
pass
class IfExistsWithCounterColumn(CQLEngineException):
pass
@@ -311,11 +311,11 @@ class AbstractQuerySet(object):
def column_family_name(self):
return self.model.column_family_name()
def _execute(self, q):
def _execute(self, statement):
if self._batch:
return self._batch.add_query(q)
return self._batch.add_query(statement)
else:
result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
result = _execute_statement(self.model, statement, self._consistency, self._timeout)
if self._if_not_exists or self._if_exists or self._conditional:
check_applied(result)
return result
@@ -1209,14 +1209,14 @@ class DMLQuery(object):
self._conditional = conditional
self._timeout = timeout
def _execute(self, q):
def _execute(self, statement):
if self._batch:
return self._batch.add_query(q)
return self._batch.add_query(statement)
else:
tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
results = _execute_statement(self.model, statement, self._consistency, self._timeout)
if self._if_not_exists or self._if_exists or self._conditional:
check_applied(tmp)
return tmp
check_applied(results)
return results
def batch(self, batch_obj):
if batch_obj is not None and not isinstance(batch_obj, BatchQuery):
@@ -1243,11 +1243,7 @@ class DMLQuery(object):
if deleted_fields:
for name, col in self.model._primary_keys.items():
ds.add_where_clause(WhereClause(
col.db_field_name,
EqualsOperator(),
col.to_database(getattr(self.instance, name))
))
ds.add_where(col, EqualsOperator(), getattr(self.instance, name))
self._execute(ds)
def update(self):
@@ -1289,11 +1285,7 @@ class DMLQuery(object):
# only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error
if (null_clustering_key or static_changed_only) and (not col.partition_key):
continue
statement.add_where_clause(WhereClause(
col.db_field_name,
EqualsOperator(),
col.to_database(getattr(self.instance, name))
))
statement.add_where(col, EqualsOperator(), getattr(self.instance, name))
self._execute(statement)
if not null_clustering_key:
@@ -1328,10 +1320,7 @@ class DMLQuery(object):
if self.instance._values[name].changed:
nulled_fields.add(col.db_field_name)
continue
insert.add_assignment_clause(AssignmentClause(
col.db_field_name,
col.to_database(getattr(self.instance, name, None))
))
insert.add_assignment(col, getattr(self.instance, name, None))
# skip query execution if it's empty
# caused by pointless update queries
@@ -1348,12 +1337,20 @@ class DMLQuery(object):
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
for name, col in self.model._primary_keys.items():
if (not col.partition_key) and (getattr(self.instance, name) is None):
val = getattr(self.instance, name)
if val is None and not col.parition_key:
continue
ds.add_where_clause(WhereClause(
col.db_field_name,
EqualsOperator(),
col.to_database(getattr(self.instance, name))
))
ds.add_where(col, EqualsOperator(), val)
self._execute(ds)
def _execute_statement(model, statement, consistency_level, timeout):
params = statement.get_context()
s = SimpleStatement(str(statement), consistency_level=consistency_level, fetch_size=statement.fetch_size)
if model._partition_key_index:
key_values = statement.partition_key_values(model._partition_key_index)
if not any(v is None for v in key_values):
parts = model._routing_key_from_values(key_values, connection.get_cluster().protocol_version)
s.routing_key = parts
s.keyspace = model._get_keyspace()
return connection.execute(s, params, timeout=timeout)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from datetime import datetime, timedelta
from itertools import ifilter
import time
import six
@@ -20,7 +21,7 @@ from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine import columns
from cassandra.cqlengine import UnicodeMixin
from cassandra.cqlengine.functions import QueryValue
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator
class StatementException(Exception):
@@ -481,10 +482,9 @@ class MapDeleteClause(BaseDeleteClause):
class BaseCQLStatement(UnicodeMixin):
""" The base cql statement class """
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
def __init__(self, table, timestamp=None, where=None, fetch_size=None, conditionals=None):
super(BaseCQLStatement, self).__init__()
self.table = table
self.consistency = consistency
self.context_id = 0
self.context_counter = self.context_id
self.timestamp = timestamp
@@ -492,20 +492,27 @@ class BaseCQLStatement(UnicodeMixin):
self.where_clauses = []
for clause in where or []:
self.add_where_clause(clause)
self._add_where_clause(clause)
self.conditionals = []
for conditional in conditionals or []:
self.add_conditional_clause(conditional)
def add_where_clause(self, clause):
"""
adds a where clause to this statement
:param clause: the clause to add
:type clause: WhereClause
"""
if not isinstance(clause, WhereClause):
raise StatementException("only instances of WhereClause can be added to statements")
def _update_part_key_values(self, field_index_map, clauses, parts):
for clause in ifilter(lambda c: c.field in field_index_map, clauses):
parts[field_index_map[clause.field]] = clause.value
def partition_key_values(self, field_index_map):
parts = [None] * len(field_index_map)
self._update_part_key_values(field_index_map, (w for w in self.where_clauses if w.operator.__class__ == EqualsOperator), parts)
return parts
def add_where(self, column, operator, value, quote_field=True):
value = column.to_database(value)
clause = WhereClause(column.db_field_name, operator, value, quote_field)
self._add_where_clause(clause)
def _add_where_clause(self, clause):
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.where_clauses.append(clause)
@@ -581,7 +588,6 @@ class SelectStatement(BaseCQLStatement):
table,
fields=None,
count=False,
consistency=None,
where=None,
order_by=None,
limit=None,
@@ -595,7 +601,6 @@ class SelectStatement(BaseCQLStatement):
"""
super(SelectStatement, self).__init__(
table,
consistency=consistency,
where=where,
fetch_size=fetch_size
)
@@ -641,14 +646,12 @@ class AssignmentStatement(BaseCQLStatement):
def __init__(self,
table,
assignments=None,
consistency=None,
where=None,
ttl=None,
timestamp=None,
conditionals=None):
super(AssignmentStatement, self).__init__(
table,
consistency=consistency,
where=where,
conditionals=conditionals
)
@@ -658,7 +661,7 @@ class AssignmentStatement(BaseCQLStatement):
# add assignments
self.assignments = []
for assignment in assignments or []:
self.add_assignment_clause(assignment)
self._add_assignment_clause(assignment)
def update_context_id(self, i):
super(AssignmentStatement, self).update_context_id(i)
@@ -666,14 +669,17 @@ class AssignmentStatement(BaseCQLStatement):
assignment.set_context_id(self.context_counter)
self.context_counter += assignment.get_context_size()
def add_assignment_clause(self, clause):
"""
adds an assignment clause to this statement
:param clause: the clause to add
:type clause: AssignmentClause
"""
if not isinstance(clause, AssignmentClause):
raise StatementException("only instances of AssignmentClause can be added to statements")
def partition_key_values(self, field_index_map):
parts = super(AssignmentStatement, self).partition_key_values(field_index_map)
self._update_part_key_values(field_index_map, self.assignments, parts)
return parts
def add_assignment(self, column, value):
value = column.to_database(value)
clause = AssignmentClause(column.db_field_name, value)
self._add_assignment_clause(clause)
def _add_assignment_clause(self, clause):
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.assignments.append(clause)
@@ -695,23 +701,18 @@ class InsertStatement(AssignmentStatement):
def __init__(self,
table,
assignments=None,
consistency=None,
where=None,
ttl=None,
timestamp=None,
if_not_exists=False):
super(InsertStatement, self).__init__(table,
assignments=assignments,
consistency=consistency,
where=where,
ttl=ttl,
timestamp=timestamp)
self.if_not_exists = if_not_exists
def add_where_clause(self, clause):
raise StatementException("Cannot add where clauses to insert statements")
def __unicode__(self):
qs = ['INSERT INTO {0}'.format(self.table)]
@@ -741,7 +742,6 @@ class UpdateStatement(AssignmentStatement):
def __init__(self,
table,
assignments=None,
consistency=None,
where=None,
ttl=None,
timestamp=None,
@@ -749,7 +749,6 @@ class UpdateStatement(AssignmentStatement):
if_exists=False):
super(UpdateStatement, self). __init__(table,
assignments=assignments,
consistency=consistency,
where=where,
ttl=ttl,
timestamp=timestamp,
@@ -809,16 +808,15 @@ class UpdateStatement(AssignmentStatement):
else:
clause = AssignmentClause(column.db_field_name, value)
if clause.get_context_size(): # this is to exclude map removals from updates. Can go away if we drop support for C* < 1.2.4 and remove two-phase updates
self.add_assignment_clause(clause)
self._add_assignment_clause(clause)
class DeleteStatement(BaseCQLStatement):
""" a cql delete statement """
def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, conditionals=None, if_exists=False):
def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=None, if_exists=False):
super(DeleteStatement, self).__init__(
table,
consistency=consistency,
where=where,
timestamp=timestamp,
conditionals=conditionals

View File

@@ -77,6 +77,7 @@ def trim_if_startswith(s, prefix):
_casstypes = {}
_cqltypes = {}
cql_type_scanner = re.Scanner((
@@ -106,6 +107,8 @@ class CassandraTypeType(type):
cls = type.__new__(metacls, name, bases, dct)
if not name.startswith('_'):
_casstypes[name] = cls
if not cls.typename.startswith("'org"):
_cqltypes[cls.typename] = cls
return cls
@@ -620,6 +623,11 @@ class SimpleDateType(_CassandraType):
try:
days = val.days_from_epoch
except AttributeError:
if isinstance(val, int):
# the DB wants offset int values, but util.Date init takes days from epoch
# here we assume int values are offset, as they would appear in CQL
# short circuit to avoid subtracting just to add offset
return uint32_pack(val)
days = util.Date(val).days_from_epoch
return uint32_pack(days + SimpleDateType.EPOCH_OFFSET_DAYS)

View File

@@ -233,13 +233,20 @@ class Statement(object):
if custom_payload is not None:
self.custom_payload = custom_payload
def _key_parts_packed(self, parts):
for p in parts:
l = len(p)
yield struct.pack(">H%dsB" % l, l, p, 0)
def _get_routing_key(self):
return self._routing_key
def _set_routing_key(self, key):
if isinstance(key, (list, tuple)):
self._routing_key = b"".join(struct.pack("HsB", len(component), component, 0)
for component in key)
if len(key) == 1:
self._routing_key = key[0]
else:
self._routing_key = b"".join(self._key_parts_packed(key))
else:
self._routing_key = key
@@ -565,13 +572,7 @@ class BoundStatement(Statement):
if len(routing_indexes) == 1:
self._routing_key = self.values[routing_indexes[0]]
else:
components = []
for statement_index in routing_indexes:
val = self.values[statement_index]
l = len(val)
components.append(struct.pack(">H%dsB" % l, l, val, 0))
self._routing_key = b"".join(components)
self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes))
return self._routing_key

View File

@@ -79,6 +79,8 @@ Model
'tombstone_compaction_interval': '86400'},
'gc_grace_seconds': '0'}
.. autoattribute:: __compute_routing_key__
The base methods allow creating, storing, and querying modeled objects.

View File

@@ -61,7 +61,6 @@ class BaseColumnIOTest(BaseCassEngTestCase):
# create a table with the given column
class IOTestModel(Model):
table_name = cls.column.db_type + "_io_test_model_{0}".format(uuid4().hex[:8])
pkey = cls.column(primary_key=True)
data = cls.column()

View File

@@ -456,8 +456,8 @@ def test_non_quality_filtering():
NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now())
NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now())
qA = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering()
num = qA.count()
qa = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering()
num = qa.count()
assert num == 1, num
@@ -472,7 +472,7 @@ class TestQuerySetDistinct(BaseQuerySetUsage):
self.assertEqual(len(q), 3)
def test_distinct_with_filter(self):
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1,2])
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2])
self.assertEqual(len(q), 2)
def test_distinct_with_non_partition(self):
@@ -509,19 +509,19 @@ class TestQuerySetOrdering(BaseQuerySetUsage):
def test_ordering_by_non_second_primary_keys_fail(self):
# kwarg filtering
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id=0).order_by('test_id')
TestModel.objects(test_id=0).order_by('test_id')
# kwarg filtering
with self.assertRaises(query.QueryException):
q = TestModel.objects(TestModel.test_id == 0).order_by('test_id')
TestModel.objects(TestModel.test_id == 0).order_by('test_id')
def test_ordering_by_non_primary_keys_fails(self):
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id=0).order_by('description')
TestModel.objects(test_id=0).order_by('description')
def test_ordering_on_indexed_columns_fails(self):
with self.assertRaises(query.QueryException):
q = IndexedTestModel.objects(test_id=0).order_by('attempt_id')
IndexedTestModel.objects(test_id=0).order_by('attempt_id')
def test_ordering_on_multiple_clustering_columns(self):
TestMultiClusteringModel.create(one=1, two=1, three=4)
@@ -672,7 +672,7 @@ class TestQuerySetDelete(BaseQuerySetUsage):
TestMultiClusteringModel.objects(one=1, two__gt=3, two__lt=5).delete()
self.assertEqual(5, len(TestMultiClusteringModel.objects.all()))
TestMultiClusteringModel.objects(one=1, two__in=[8,9]).delete()
TestMultiClusteringModel.objects(one=1, two__in=[8, 9]).delete()
self.assertEqual(3, len(TestMultiClusteringModel.objects.all()))
TestMultiClusteringModel.objects(one__in=[1], two__gte=0).delete()
@@ -877,7 +877,7 @@ class TestValuesList(BaseQuerySetUsage):
class TestObjectsProperty(BaseQuerySetUsage):
def test_objects_property_returns_fresh_queryset(self):
assert TestModel.objects._result_cache is None
len(TestModel.objects) # evaluate queryset
len(TestModel.objects) # evaluate queryset
assert TestModel.objects._result_cache is None

View File

@@ -1,28 +0,0 @@
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from cassandra.cqlengine.statements import AssignmentStatement, StatementException
class AssignmentStatementTest(unittest.TestCase):
def test_add_assignment_type_checking(self):
""" tests that only assignment clauses can be added to queries """
stmt = AssignmentStatement('table', [])
with self.assertRaises(StatementException):
stmt.add_assignment_clause('x=5')

View File

@@ -17,17 +17,11 @@ except ImportError:
import unittest # noqa
from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine.statements import BaseCQLStatement, StatementException
from cassandra.cqlengine.statements import BaseCQLStatement
class BaseStatementTest(unittest.TestCase):
def test_where_clause_type_checking(self):
""" tests that only assignment clauses can be added to queries """
stmt = BaseCQLStatement('table', [])
with self.assertRaises(StatementException):
stmt.add_where_clause('x=5')
def test_fetch_size(self):
""" tests that fetch_size is correctly set """
stmt = BaseCQLStatement('table', None, fetch_size=1000)

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from unittest import TestCase
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause
from cassandra.cqlengine.operators import *
import six
@@ -45,13 +47,13 @@ class DeleteStatementTests(TestCase):
def test_where_clause_rendering(self):
ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s', six.text_type(ds))
def test_context_update(self):
ds = DeleteStatement('table', None)
ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4}))
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
ds.update_context_id(7)
self.assertEqual(six.text_type(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s')
@@ -59,19 +61,19 @@ class DeleteStatementTests(TestCase):
def test_context(self):
ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(ds.get_context(), {'0': 'b'})
def test_range_deletion_rendering(self):
ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where_clause(WhereClause('created_at', GreaterThanOrEqualOperator(), '0'))
ds.add_where_clause(WhereClause('created_at', LessThanOrEqualOperator(), '10'))
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
ds.add_where(Column(db_field='created_at'), GreaterThanOrEqualOperator(), '0')
ds.add_where(Column(db_field='created_at'), LessThanOrEqualOperator(), '10')
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', six.text_type(ds))
ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where_clause(WhereClause('created_at', InOperator(), ['0', '10', '20']))
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
ds.add_where(Column(db_field='created_at'), InOperator(), ['0', '10', '20'])
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds))
def test_delete_conditional(self):

View File

@@ -16,22 +16,18 @@ try:
except ImportError:
import unittest # noqa
from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause
import six
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import InsertStatement
class InsertStatementTests(unittest.TestCase):
def test_where_clause_failure(self):
""" tests that where clauses cannot be added to Insert statements """
ist = InsertStatement('table', None)
with self.assertRaises(StatementException):
ist.add_where_clause('s')
def test_statement(self):
ist = InsertStatement('table', None)
ist.add_assignment_clause(AssignmentClause('a', 'b'))
ist.add_assignment_clause(AssignmentClause('c', 'd'))
ist.add_assignment(Column(db_field='a'), 'b')
ist.add_assignment(Column(db_field='c'), 'd')
self.assertEqual(
six.text_type(ist),
@@ -40,8 +36,8 @@ class InsertStatementTests(unittest.TestCase):
def test_context_update(self):
ist = InsertStatement('table', None)
ist.add_assignment_clause(AssignmentClause('a', 'b'))
ist.add_assignment_clause(AssignmentClause('c', 'd'))
ist.add_assignment(Column(db_field='a'), 'b')
ist.add_assignment(Column(db_field='c'), 'd')
ist.update_context_id(4)
self.assertEqual(
@@ -53,6 +49,6 @@ class InsertStatementTests(unittest.TestCase):
def test_additional_rendering(self):
ist = InsertStatement('table', ttl=60)
ist.add_assignment_clause(AssignmentClause('a', 'b'))
ist.add_assignment_clause(AssignmentClause('c', 'd'))
ist.add_assignment(Column(db_field='a'), 'b')
ist.add_assignment(Column(db_field='c'), 'd')
self.assertIn('USING TTL 60', six.text_type(ist))

View File

@@ -16,6 +16,7 @@ try:
except ImportError:
import unittest # noqa
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import SelectStatement, WhereClause
from cassandra.cqlengine.operators import *
import six
@@ -46,19 +47,19 @@ class SelectStatementTests(unittest.TestCase):
def test_where_clause_rendering(self):
ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ss), 'SELECT * FROM table WHERE "a" = %(0)s', six.text_type(ss))
def test_count(self):
ss = SelectStatement('table', count=True, limit=10, order_by='d')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', six.text_type(ss))
self.assertIn('LIMIT', six.text_type(ss))
self.assertNotIn('ORDER', six.text_type(ss))
def test_distinct(self):
ss = SelectStatement('table', distinct_fields=['field2'])
ss.add_where_clause(WhereClause('field1', EqualsOperator(), 'b'))
ss.add_where(Column(db_field='field1'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', six.text_type(ss))
ss = SelectStatement('table', distinct_fields=['field1', 'field2'])
@@ -69,13 +70,13 @@ class SelectStatementTests(unittest.TestCase):
def test_context(self):
ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(ss.get_context(), {'0': 'b'})
def test_context_id_update(self):
""" tests that the right things happen the the context id """
ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(ss.get_context(), {'0': 'b'})
self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s')

View File

@@ -16,6 +16,7 @@ try:
except ImportError:
import unittest # noqa
from cassandra.cqlengine.columns import Column, Set, List, Text
from cassandra.cqlengine.operators import *
from cassandra.cqlengine.statements import (UpdateStatement, WhereClause,
AssignmentClause, SetUpdateClause,
@@ -33,54 +34,54 @@ class UpdateStatementTests(unittest.TestCase):
def test_rendering(self):
us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_assignment_clause(AssignmentClause('c', 'd'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
us.add_assignment(Column(db_field='a'), 'b')
us.add_assignment(Column(db_field='c'), 'd')
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', six.text_type(us))
def test_context(self):
us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_assignment_clause(AssignmentClause('c', 'd'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
us.add_assignment(Column(db_field='a'), 'b')
us.add_assignment(Column(db_field='c'), 'd')
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'})
def test_context_update(self):
us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_assignment_clause(AssignmentClause('c', 'd'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
us.add_assignment(Column(db_field='a'), 'b')
us.add_assignment(Column(db_field='c'), 'd')
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
us.update_context_id(3)
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s')
self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'})
def test_additional_rendering(self):
us = UpdateStatement('table', ttl=60)
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
us.add_assignment(Column(db_field='a'), 'b')
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
self.assertIn('USING TTL 60', six.text_type(us))
def test_update_set_add(self):
us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', set((1,)), operation='add'))
us.add_update(Set(Text, db_field='a'), set((1,)), 'add')
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
def test_update_empty_set_add_does_not_assign(self):
us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', set(), operation='add'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
us.add_update(Set(Text, db_field='a'), set(), 'add')
self.assertFalse(us.assignments)
def test_update_empty_set_removal_does_not_assign(self):
us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', set(), operation='remove'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" - %(0)s')
us.add_update(Set(Text, db_field='a'), set(), 'remove')
self.assertFalse(us.assignments)
def test_update_list_prepend_with_empty_list(self):
us = UpdateStatement('table')
us.add_assignment_clause(ListUpdateClause('a', [], operation='prepend'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s + "a"')
us.add_update(List(Text, db_field='a'), [], 'prepend')
self.assertFalse(us.assignments)
def test_update_list_append_with_empty_list(self):
us = UpdateStatement('table')
us.add_assignment_clause(ListUpdateClause('a', [], operation='append'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
us.add_update(List(Text, db_field='a'), [], 'append')
self.assertFalse(us.assignments)

View File

@@ -74,7 +74,7 @@ class RoutingTests(unittest.TestCase):
select = s.prepare("SELECT token(%s) FROM %s WHERE %s" %
(primary_key, table_name, where_clause))
return (insert, select)
return insert, select
def test_singular_key(self):
# string