Merge pull request #523 from datastax/272
PYTHON-272 - Make Token function work with cqlengine NamedTable
This commit is contained in:
@@ -881,12 +881,3 @@ class _PartitionKeysToken(Column):
|
|||||||
@property
|
@property
|
||||||
def db_field_name(self):
|
def db_field_name(self):
|
||||||
return 'token({0})'.format(', '.join(['"{0}"'.format(c.db_field_name) for c in self.partition_columns]))
|
return 'token({0})'.format(', '.join(['"{0}"'.format(c.db_field_name) for c in self.partition_columns]))
|
||||||
|
|
||||||
def to_database(self, value):
|
|
||||||
from cqlengine.functions import Token
|
|
||||||
assert isinstance(value, Token)
|
|
||||||
value.set_columns(self.partition_columns)
|
|
||||||
return value
|
|
||||||
|
|
||||||
def get_cql(self):
|
|
||||||
return "token({0})".format(", ".join(c.cql for c in self.partition_columns))
|
|
||||||
|
|||||||
@@ -100,14 +100,12 @@ class MaxTimeUUID(TimeUUIDQueryFunction):
|
|||||||
format_string = 'MaxTimeUUID(%({0})s)'
|
format_string = 'MaxTimeUUID(%({0})s)'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseQueryFunction):
|
class Token(BaseQueryFunction):
|
||||||
"""
|
"""
|
||||||
compute the token for a given partition key
|
compute the token for a given partition key
|
||||||
|
|
||||||
http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun
|
http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *values):
|
def __init__(self, *values):
|
||||||
if len(values) == 1 and isinstance(values[0], (list, tuple)):
|
if len(values) == 1 and isinstance(values[0], (list, tuple)):
|
||||||
values = values[0]
|
values = values[0]
|
||||||
|
|||||||
@@ -12,7 +12,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from cassandra.util import OrderedDict
|
||||||
|
|
||||||
from cassandra.cqlengine import CQLEngineException
|
from cassandra.cqlengine import CQLEngineException
|
||||||
|
from cassandra.cqlengine.columns import Column
|
||||||
|
from cassandra.cqlengine.connection import get_cluster
|
||||||
from cassandra.cqlengine.query import AbstractQueryableColumn, SimpleQuerySet
|
from cassandra.cqlengine.query import AbstractQueryableColumn, SimpleQuerySet
|
||||||
from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
|
from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
|
||||||
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
|
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
|
||||||
@@ -78,6 +82,8 @@ class NamedTable(object):
|
|||||||
|
|
||||||
objects = QuerySetDescriptor()
|
objects = QuerySetDescriptor()
|
||||||
|
|
||||||
|
__partition_keys = None
|
||||||
|
|
||||||
class DoesNotExist(_DoesNotExist):
|
class DoesNotExist(_DoesNotExist):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -88,6 +94,20 @@ class NamedTable(object):
|
|||||||
self.keyspace = keyspace
|
self.keyspace = keyspace
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _partition_keys(self):
|
||||||
|
if not self.__partition_keys:
|
||||||
|
self._get_partition_keys()
|
||||||
|
return self.__partition_keys
|
||||||
|
|
||||||
|
def _get_partition_keys(self):
|
||||||
|
try:
|
||||||
|
table_meta = get_cluster().metadata.keyspaces[self.keyspace].tables[self.name]
|
||||||
|
self.__partition_keys = OrderedDict((pk.name, Column(primary_key=True, partition_key=True, db_field=pk.name)) for pk in table_meta.partition_key)
|
||||||
|
except Exception as e:
|
||||||
|
raise CQLEngineException("Failed inspecting partition keys for {0}."
|
||||||
|
"Ensure cqlengine is connected before attempting this with NamedTable.".format(self.column_family_name()))
|
||||||
|
|
||||||
def column(self, name):
|
def column(self, name):
|
||||||
return NamedColumn(name)
|
return NamedColumn(name)
|
||||||
|
|
||||||
|
|||||||
@@ -550,28 +550,15 @@ class AbstractQuerySet(object):
|
|||||||
clone._conditional.append(operator)
|
clone._conditional.append(operator)
|
||||||
|
|
||||||
for col_name, val in kwargs.items():
|
for col_name, val in kwargs.items():
|
||||||
exists = False
|
if isinstance(val, Token):
|
||||||
|
raise QueryException("Token() values are not valid in conditionals")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
column = self.model._get_column(col_name)
|
column = self.model._get_column(col_name)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if col_name == 'pk__token':
|
raise QueryException("Can't resolve column name: '{0}'".format(col_name))
|
||||||
if not isinstance(val, Token):
|
|
||||||
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
|
|
||||||
column = columns._PartitionKeysToken(self.model)
|
|
||||||
else:
|
|
||||||
raise QueryException("Can't resolve column name: '{0}'".format(col_name))
|
|
||||||
|
|
||||||
if isinstance(val, Token):
|
if isinstance(val, BaseQueryFunction):
|
||||||
if col_name != 'pk__token':
|
|
||||||
raise QueryException("Token() values may only be compared to the 'pk__token' virtual column")
|
|
||||||
partition_columns = column.partition_columns
|
|
||||||
if len(partition_columns) != len(val.value):
|
|
||||||
raise QueryException(
|
|
||||||
'Token() received {0} arguments but model has {1} partition keys'.format(
|
|
||||||
len(val.value), len(partition_columns)))
|
|
||||||
val.set_columns(partition_columns)
|
|
||||||
|
|
||||||
if isinstance(val, BaseQueryFunction) or exists is True:
|
|
||||||
query_val = val
|
query_val = val
|
||||||
else:
|
else:
|
||||||
query_val = column.to_database(val)
|
query_val = column.to_database(val)
|
||||||
@@ -601,21 +588,19 @@ class AbstractQuerySet(object):
|
|||||||
for arg, val in kwargs.items():
|
for arg, val in kwargs.items():
|
||||||
col_name, col_op = self._parse_filter_arg(arg)
|
col_name, col_op = self._parse_filter_arg(arg)
|
||||||
quote_field = True
|
quote_field = True
|
||||||
# resolve column and operator
|
|
||||||
try:
|
|
||||||
column = self.model._get_column(col_name)
|
|
||||||
except KeyError:
|
|
||||||
if col_name == 'pk__token':
|
|
||||||
if not isinstance(val, Token):
|
|
||||||
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
|
|
||||||
column = columns._PartitionKeysToken(self.model)
|
|
||||||
quote_field = False
|
|
||||||
else:
|
|
||||||
raise QueryException("Can't resolve column name: '{0}'".format(col_name))
|
|
||||||
|
|
||||||
if isinstance(val, Token):
|
if not isinstance(val, Token):
|
||||||
|
try:
|
||||||
|
column = self.model._get_column(col_name)
|
||||||
|
except KeyError:
|
||||||
|
raise QueryException("Can't resolve column name: '{0}'".format(col_name))
|
||||||
|
else:
|
||||||
if col_name != 'pk__token':
|
if col_name != 'pk__token':
|
||||||
raise QueryException("Token() values may only be compared to the 'pk__token' virtual column")
|
raise QueryException("Token() values may only be compared to the 'pk__token' virtual column")
|
||||||
|
|
||||||
|
column = columns._PartitionKeysToken(self.model)
|
||||||
|
quote_field = False
|
||||||
|
|
||||||
partition_columns = column.partition_columns
|
partition_columns = column.partition_columns
|
||||||
if len(partition_columns) != len(val.value):
|
if len(partition_columns) != len(val.value):
|
||||||
raise QueryException(
|
raise QueryException(
|
||||||
@@ -955,13 +940,13 @@ class ModelQuerySet(AbstractQuerySet):
|
|||||||
# check that there's either a =, a IN or a CONTAINS (collection) relationship with a primary key or indexed field
|
# check that there's either a =, a IN or a CONTAINS (collection) relationship with a primary key or indexed field
|
||||||
equal_ops = [self.model._get_column_by_db_name(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)]
|
equal_ops = [self.model._get_column_by_db_name(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)]
|
||||||
token_comparison = any([w for w in self._where if isinstance(w.value, Token)])
|
token_comparison = any([w for w in self._where if isinstance(w.value, Token)])
|
||||||
if not any([w.primary_key or w.index for w in equal_ops]) and not token_comparison and not self._allow_filtering:
|
if not any(w.primary_key or w.index for w in equal_ops) and not token_comparison and not self._allow_filtering:
|
||||||
raise QueryException(('Where clauses require either =, a IN or a CONTAINS (collection) '
|
raise QueryException(('Where clauses require either =, a IN or a CONTAINS (collection) '
|
||||||
'comparison with either a primary key or indexed field'))
|
'comparison with either a primary key or indexed field'))
|
||||||
|
|
||||||
if not self._allow_filtering:
|
if not self._allow_filtering:
|
||||||
# if the query is not on an indexed field
|
# if the query is not on an indexed field
|
||||||
if not any([w.index for w in equal_ops]):
|
if not any(w.index for w in equal_ops):
|
||||||
if not any([w.partition_key for w in equal_ops]) and not token_comparison:
|
if not any([w.partition_key for w in equal_ops]) and not token_comparison:
|
||||||
raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset')
|
raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset')
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user