PEP-8 cleanup
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
#column field types
|
||||
from copy import deepcopy, copy
|
||||
from datetime import datetime
|
||||
from datetime import date
|
||||
@@ -86,7 +85,7 @@ class ValueQuoter(object):
|
||||
|
||||
class Column(object):
|
||||
|
||||
#the cassandra type this column maps to
|
||||
# the cassandra type this column maps to
|
||||
db_type = None
|
||||
value_manager = BaseValueManager
|
||||
|
||||
@@ -161,13 +160,13 @@ class Column(object):
|
||||
self.required = required
|
||||
self.clustering_order = clustering_order
|
||||
self.polymorphic_key = polymorphic_key
|
||||
#the column name in the model definition
|
||||
# the column name in the model definition
|
||||
self.column_name = None
|
||||
self.static = static
|
||||
|
||||
self.value = None
|
||||
|
||||
#keep track of instantiation order
|
||||
# keep track of instantiation order
|
||||
self.position = Column.instance_counter
|
||||
Column.instance_counter += 1
|
||||
|
||||
@@ -266,17 +265,18 @@ class Blob(Column):
|
||||
return bytearray(val)
|
||||
|
||||
def to_python(self, value):
|
||||
#return value[2:].decode('hex')
|
||||
return value
|
||||
|
||||
Bytes = Blob
|
||||
|
||||
|
||||
class Ascii(Column):
|
||||
"""
|
||||
Stores a US-ASCII character string
|
||||
"""
|
||||
db_type = 'ascii'
|
||||
|
||||
|
||||
class Inet(Column):
|
||||
"""
|
||||
Stores an IP address in IPv4 or IPv6 format
|
||||
@@ -309,7 +309,8 @@ class Text(Column):
|
||||
|
||||
def validate(self, value):
|
||||
value = super(Text, self).validate(value)
|
||||
if value is None: return
|
||||
if value is None:
|
||||
return
|
||||
if not isinstance(value, (six.string_types, bytearray)) and value is not None:
|
||||
raise ValidationError('{} {} is not a string'.format(self.column_name, type(value)))
|
||||
if self.max_length:
|
||||
@@ -330,7 +331,8 @@ class Integer(Column):
|
||||
|
||||
def validate(self, value):
|
||||
val = super(Integer, self).validate(value)
|
||||
if val is None: return
|
||||
if val is None:
|
||||
return
|
||||
try:
|
||||
return int(val)
|
||||
except (TypeError, ValueError):
|
||||
@@ -409,7 +411,8 @@ class DateTime(Column):
|
||||
db_type = 'timestamp'
|
||||
|
||||
def to_python(self, value):
|
||||
if value is None: return
|
||||
if value is None:
|
||||
return
|
||||
if isinstance(value, datetime):
|
||||
return value
|
||||
elif isinstance(value, date):
|
||||
@@ -421,7 +424,8 @@ class DateTime(Column):
|
||||
|
||||
def to_database(self, value):
|
||||
value = super(DateTime, self).to_database(value)
|
||||
if value is None: return
|
||||
if value is None:
|
||||
return
|
||||
if not isinstance(value, datetime):
|
||||
if isinstance(value, date):
|
||||
value = datetime(value.year, value.month, value.day)
|
||||
@@ -443,7 +447,8 @@ class Date(Column):
|
||||
db_type = 'timestamp'
|
||||
|
||||
def to_python(self, value):
|
||||
if value is None: return
|
||||
if value is None:
|
||||
return
|
||||
if isinstance(value, datetime):
|
||||
return value.date()
|
||||
elif isinstance(value, date):
|
||||
@@ -455,7 +460,8 @@ class Date(Column):
|
||||
|
||||
def to_database(self, value):
|
||||
value = super(Date, self).to_database(value)
|
||||
if value is None: return
|
||||
if value is None:
|
||||
return
|
||||
if isinstance(value, datetime):
|
||||
value = value.date()
|
||||
if not isinstance(value, date):
|
||||
@@ -474,9 +480,11 @@ class UUID(Column):
|
||||
|
||||
def validate(self, value):
|
||||
val = super(UUID, self).validate(value)
|
||||
if val is None: return
|
||||
if val is None:
|
||||
return
|
||||
from uuid import UUID as _UUID
|
||||
if isinstance(val, _UUID): return val
|
||||
if isinstance(val, _UUID):
|
||||
return val
|
||||
if isinstance(val, six.string_types) and self.re_uuid.match(val):
|
||||
return _UUID(val)
|
||||
raise ValidationError("{} {} is not a valid uuid".format(self.column_name, value))
|
||||
@@ -510,7 +518,7 @@ class TimeUUID(UUID):
|
||||
|
||||
epoch = datetime(1970, 1, 1, tzinfo=dt.tzinfo)
|
||||
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
|
||||
timestamp = (dt - epoch).total_seconds() - offset
|
||||
timestamp = (dt - epoch).total_seconds() - offset
|
||||
|
||||
node = None
|
||||
clock_seq = None
|
||||
@@ -529,7 +537,7 @@ class TimeUUID(UUID):
|
||||
if node is None:
|
||||
node = getnode()
|
||||
return pyUUID(fields=(time_low, time_mid, time_hi_version,
|
||||
clock_seq_hi_variant, clock_seq_low, node), version=1)
|
||||
clock_seq_hi_variant, clock_seq_low, node), version=1)
|
||||
|
||||
|
||||
class Boolean(Column):
|
||||
@@ -563,7 +571,8 @@ class Float(Column):
|
||||
|
||||
def validate(self, value):
|
||||
value = super(Float, self).validate(value)
|
||||
if value is None: return
|
||||
if value is None:
|
||||
return
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
@@ -586,7 +595,8 @@ class Decimal(Column):
|
||||
from decimal import Decimal as _Decimal
|
||||
from decimal import InvalidOperation
|
||||
val = super(Decimal, self).validate(value)
|
||||
if val is None: return
|
||||
if val is None:
|
||||
return
|
||||
try:
|
||||
return _Decimal(val)
|
||||
except InvalidOperation:
|
||||
@@ -679,7 +689,8 @@ class Set(BaseContainerColumn):
|
||||
|
||||
def validate(self, value):
|
||||
val = super(Set, self).validate(value)
|
||||
if val is None: return
|
||||
if val is None:
|
||||
return
|
||||
types = (set,) if self.strict else (set, list, tuple)
|
||||
if not isinstance(val, types):
|
||||
if self.strict:
|
||||
@@ -693,18 +704,19 @@ class Set(BaseContainerColumn):
|
||||
return {self.value_col.validate(v) for v in val}
|
||||
|
||||
def to_python(self, value):
|
||||
if value is None: return set()
|
||||
if value is None:
|
||||
return set()
|
||||
return {self.value_col.to_python(v) for v in value}
|
||||
|
||||
def to_database(self, value):
|
||||
if value is None: return None
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, self.Quoter): return value
|
||||
if isinstance(value, self.Quoter):
|
||||
return value
|
||||
return self.Quoter({self.value_col.to_database(v) for v in value})
|
||||
|
||||
|
||||
|
||||
|
||||
class List(BaseContainerColumn):
|
||||
"""
|
||||
Stores a list of ordered values
|
||||
@@ -727,7 +739,8 @@ class List(BaseContainerColumn):
|
||||
|
||||
def validate(self, value):
|
||||
val = super(List, self).validate(value)
|
||||
if val is None: return
|
||||
if val is None:
|
||||
return
|
||||
if not isinstance(val, (set, list, tuple)):
|
||||
raise ValidationError('{} {} is not a list object'.format(self.column_name, val))
|
||||
if None in val:
|
||||
@@ -735,12 +748,15 @@ class List(BaseContainerColumn):
|
||||
return [self.value_col.validate(v) for v in val]
|
||||
|
||||
def to_python(self, value):
|
||||
if value is None: return []
|
||||
if value is None:
|
||||
return []
|
||||
return [self.value_col.to_python(v) for v in value]
|
||||
|
||||
def to_database(self, value):
|
||||
if value is None: return None
|
||||
if isinstance(value, self.Quoter): return value
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, self.Quoter):
|
||||
return value
|
||||
return self.Quoter([self.value_col.to_database(v) for v in value])
|
||||
|
||||
|
||||
@@ -757,7 +773,7 @@ class Map(BaseContainerColumn):
|
||||
|
||||
def __str__(self):
|
||||
cq = cql_quote
|
||||
return '{' + ', '.join([cq(k) + ':' + cq(v) for k,v in self.value.items()]) + '}'
|
||||
return '{' + ', '.join([cq(k) + ':' + cq(v) for k, v in self.value.items()]) + '}'
|
||||
|
||||
def get(self, key):
|
||||
return self.value.get(key)
|
||||
@@ -802,22 +818,24 @@ class Map(BaseContainerColumn):
|
||||
|
||||
def validate(self, value):
|
||||
val = super(Map, self).validate(value)
|
||||
if val is None: return
|
||||
if val is None:
|
||||
return
|
||||
if not isinstance(val, dict):
|
||||
raise ValidationError('{} {} is not a dict object'.format(self.column_name, val))
|
||||
return {self.key_col.validate(k):self.value_col.validate(v) for k,v in val.items()}
|
||||
return {self.key_col.validate(k): self.value_col.validate(v) for k, v in val.items()}
|
||||
|
||||
def to_python(self, value):
|
||||
if value is None:
|
||||
return {}
|
||||
if value is not None:
|
||||
return {self.key_col.to_python(k): self.value_col.to_python(v) for k,v in value.items()}
|
||||
return {self.key_col.to_python(k): self.value_col.to_python(v) for k, v in value.items()}
|
||||
|
||||
def to_database(self, value):
|
||||
if value is None: return None
|
||||
if isinstance(value, self.Quoter): return value
|
||||
return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()})
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, self.Quoter):
|
||||
return value
|
||||
return self.Quoter({self.key_col.to_database(k): self.value_col.to_database(v) for k, v in value.items()})
|
||||
|
||||
|
||||
class _PartitionKeysToken(Column):
|
||||
@@ -842,4 +860,3 @@ class _PartitionKeysToken(Column):
|
||||
|
||||
def get_cql(self):
|
||||
return "token({})".format(", ".join(c.cql for c in self.partition_columns))
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections import namedtuple
|
||||
import six
|
||||
import logging
|
||||
import six
|
||||
|
||||
from cassandra import ConsistencyLevel
|
||||
from cassandra.cluster import Cluster, _NOT_SET, NoHostAvailable
|
||||
@@ -9,11 +9,13 @@ from cassandra.cqlengine.exceptions import CQLEngineException, UndefinedKeyspace
|
||||
from cassandra.cqlengine.statements import BaseCQLStatement
|
||||
|
||||
|
||||
LOG = logging.getLogger('cqlengine.cql')
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
NOT_SET = _NOT_SET # required for passing timeout to Session.execute
|
||||
|
||||
|
||||
class CQLConnectionError(CQLEngineException): pass
|
||||
class CQLConnectionError(CQLEngineException):
|
||||
pass
|
||||
|
||||
Host = namedtuple('Host', ['name', 'port'])
|
||||
|
||||
@@ -22,6 +24,7 @@ session = None
|
||||
lazy_connect_args = None
|
||||
default_consistency_level = None
|
||||
|
||||
|
||||
def setup(
|
||||
hosts,
|
||||
default_keyspace,
|
||||
@@ -94,21 +97,24 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET):
|
||||
elif isinstance(query, six.string_types):
|
||||
query = SimpleStatement(query, consistency_level=consistency_level)
|
||||
|
||||
LOG.info(query.query_string)
|
||||
log.debug(query.query_string)
|
||||
|
||||
params = params or {}
|
||||
result = session.execute(query, params, timeout=timeout)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_session():
|
||||
handle_lazy_connect()
|
||||
return session
|
||||
|
||||
|
||||
def get_cluster():
|
||||
handle_lazy_connect()
|
||||
return cluster
|
||||
|
||||
|
||||
def handle_lazy_connect():
|
||||
global lazy_connect_args
|
||||
if lazy_connect_args:
|
||||
|
||||
@@ -1,8 +1,22 @@
|
||||
#cqlengine exceptions
|
||||
class CQLEngineException(Exception): pass
|
||||
class ModelException(CQLEngineException): pass
|
||||
class ValidationError(CQLEngineException): pass
|
||||
class CQLEngineException(Exception):
|
||||
pass
|
||||
|
||||
class UndefinedKeyspaceException(CQLEngineException): pass
|
||||
class LWTException(CQLEngineException): pass
|
||||
class IfNotExistsWithCounterColumn(CQLEngineException): pass
|
||||
|
||||
class ModelException(CQLEngineException):
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(CQLEngineException):
|
||||
pass
|
||||
|
||||
|
||||
class UndefinedKeyspaceException(CQLEngineException):
|
||||
pass
|
||||
|
||||
|
||||
class LWTException(CQLEngineException):
|
||||
pass
|
||||
|
||||
|
||||
class IfNotExistsWithCounterColumn(CQLEngineException):
|
||||
pass
|
||||
|
||||
@@ -5,12 +5,14 @@ import sys
|
||||
from cassandra.cqlengine.exceptions import ValidationError
|
||||
# move to central spot
|
||||
|
||||
|
||||
class UnicodeMixin(object):
|
||||
if sys.version_info > (3, 0):
|
||||
__str__ = lambda x: x.__unicode__()
|
||||
else:
|
||||
__str__ = lambda x: six.text_type(x).encode('utf-8')
|
||||
|
||||
|
||||
class QueryValue(UnicodeMixin):
|
||||
"""
|
||||
Base class for query filter values. Subclasses of these classes can
|
||||
@@ -42,6 +44,8 @@ class BaseQueryFunction(QueryValue):
|
||||
be passed into .filter() and will be translated into CQL functions in
|
||||
the resulting query
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class MinTimeUUID(BaseQueryFunction):
|
||||
"""
|
||||
@@ -123,4 +127,3 @@ class Token(BaseQueryFunction):
|
||||
def update_context(self, ctx):
|
||||
for i, (col, val) in enumerate(zip(self._columns, self.value)):
|
||||
ctx[str(self.context_id + i)] = col.to_database(val)
|
||||
|
||||
|
||||
@@ -12,11 +12,12 @@ from cassandra.cqlengine.named import NamedTable
|
||||
|
||||
Field = namedtuple('Field', ['name', 'type'])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# system keyspaces
|
||||
schema_columnfamilies = NamedTable('system', 'schema_columnfamilies')
|
||||
|
||||
|
||||
def create_keyspace(name, strategy_class, replication_factor, durable_writes=True, **replication_values):
|
||||
"""
|
||||
*Deprecated - this will likely be repaced with something specialized per replication strategy.*
|
||||
@@ -38,10 +39,10 @@ def create_keyspace(name, strategy_class, replication_factor, durable_writes=Tru
|
||||
cluster = get_cluster()
|
||||
|
||||
if name not in cluster.metadata.keyspaces:
|
||||
#try the 1.2 method
|
||||
# try the 1.2 method
|
||||
replication_map = {
|
||||
'class': strategy_class,
|
||||
'replication_factor':replication_factor
|
||||
'replication_factor': replication_factor
|
||||
}
|
||||
replication_map.update(replication_values)
|
||||
if strategy_class.lower() != 'simplestrategy':
|
||||
@@ -76,6 +77,7 @@ def delete_keyspace(name):
|
||||
if name in cluster.metadata.keyspaces:
|
||||
execute("DROP KEYSPACE {}".format(name))
|
||||
|
||||
|
||||
def sync_table(model):
|
||||
"""
|
||||
Inspects the model and creates / updates the corresponding table and columns.
|
||||
@@ -95,8 +97,7 @@ def sync_table(model):
|
||||
if model.__abstract__:
|
||||
raise CQLEngineException("cannot create table from abstract model")
|
||||
|
||||
|
||||
#construct query string
|
||||
# construct query string
|
||||
cf_name = model.column_family_name()
|
||||
raw_cf_name = model.column_family_name(include_keyspace=False)
|
||||
|
||||
@@ -107,7 +108,7 @@ def sync_table(model):
|
||||
keyspace = cluster.metadata.keyspaces[ks_name]
|
||||
tables = keyspace.tables
|
||||
|
||||
#check for an existing column family
|
||||
# check for an existing column family
|
||||
if raw_cf_name not in tables:
|
||||
qs = get_create_table(model)
|
||||
|
||||
@@ -123,20 +124,21 @@ def sync_table(model):
|
||||
fields = get_fields(model)
|
||||
field_names = [x.name for x in fields]
|
||||
for name, col in model._columns.items():
|
||||
if col.primary_key or col.partition_key: continue # we can't mess with the PK
|
||||
if col.db_field_name in field_names: continue # skip columns already defined
|
||||
if col.primary_key or col.partition_key:
|
||||
continue # we can't mess with the PK
|
||||
if col.db_field_name in field_names:
|
||||
continue # skip columns already defined
|
||||
|
||||
# add missing column using the column def
|
||||
query = "ALTER TABLE {} add {}".format(cf_name, col.get_column_def())
|
||||
logger.debug(query)
|
||||
log.debug(query)
|
||||
execute(query)
|
||||
|
||||
update_compaction(model)
|
||||
|
||||
|
||||
table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name]
|
||||
|
||||
indexes = [c for n,c in model._columns.items() if c.index]
|
||||
indexes = [c for n, c in model._columns.items() if c.index]
|
||||
|
||||
for column in indexes:
|
||||
if table.columns[column.db_field_name].index:
|
||||
@@ -148,20 +150,23 @@ def sync_table(model):
|
||||
qs = ' '.join(qs)
|
||||
execute(qs)
|
||||
|
||||
|
||||
def get_create_table(model):
|
||||
cf_name = model.column_family_name()
|
||||
qs = ['CREATE TABLE {}'.format(cf_name)]
|
||||
|
||||
#add column types
|
||||
pkeys = [] # primary keys
|
||||
ckeys = [] # clustering keys
|
||||
qtypes = [] # field types
|
||||
# add column types
|
||||
pkeys = [] # primary keys
|
||||
ckeys = [] # clustering keys
|
||||
qtypes = [] # field types
|
||||
|
||||
def add_column(col):
|
||||
s = col.get_column_def()
|
||||
if col.primary_key:
|
||||
keys = (pkeys if col.partition_key else ckeys)
|
||||
keys.append('"{}"'.format(col.db_field_name))
|
||||
qtypes.append(s)
|
||||
|
||||
for name, col in model._columns.items():
|
||||
add_column(col)
|
||||
|
||||
@@ -172,9 +177,9 @@ def get_create_table(model):
|
||||
with_qs = []
|
||||
|
||||
table_properties = ['bloom_filter_fp_chance', 'caching', 'comment',
|
||||
'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds',
|
||||
'index_interval', 'memtable_flush_period_in_ms', 'populate_io_cache_on_flush',
|
||||
'read_repair_chance', 'replicate_on_write']
|
||||
'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds',
|
||||
'index_interval', 'memtable_flush_period_in_ms', 'populate_io_cache_on_flush',
|
||||
'read_repair_chance', 'replicate_on_write']
|
||||
for prop_name in table_properties:
|
||||
prop_value = getattr(model, '__{}__'.format(prop_name), None)
|
||||
if prop_value is not None:
|
||||
@@ -211,9 +216,9 @@ def get_compaction_options(model):
|
||||
if not model.__compaction__:
|
||||
return {}
|
||||
|
||||
result = {'class':model.__compaction__}
|
||||
result = {'class': model.__compaction__}
|
||||
|
||||
def setter(key, limited_to_strategy = None):
|
||||
def setter(key, limited_to_strategy=None):
|
||||
"""
|
||||
sets key in result, checking if the key is limited to either SizeTiered or Leveled
|
||||
:param key: one of the compaction options, like "bucket_high"
|
||||
@@ -270,6 +275,7 @@ def get_table_settings(model):
|
||||
table = cluster.metadata.keyspaces[ks].tables[table]
|
||||
return table
|
||||
|
||||
|
||||
def update_compaction(model):
|
||||
"""Updates the compaction options for the given model if necessary.
|
||||
|
||||
@@ -279,7 +285,7 @@ def update_compaction(model):
|
||||
`False` otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
logger.debug("Checking %s for compaction differences", model)
|
||||
log.debug("Checking %s for compaction differences", model)
|
||||
table = get_table_settings(model)
|
||||
|
||||
existing_options = table.options.copy()
|
||||
@@ -311,7 +317,7 @@ def update_compaction(model):
|
||||
options = json.dumps(options).replace('"', "'")
|
||||
cf_name = model.column_family_name()
|
||||
query = "ALTER TABLE {} with compaction = {}".format(cf_name, options)
|
||||
logger.debug(query)
|
||||
log.debug(query)
|
||||
execute(query)
|
||||
return True
|
||||
|
||||
@@ -339,6 +345,3 @@ def drop_table(model):
|
||||
execute('drop table {};'.format(model.column_family_name(include_keyspace=True)))
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
import six
|
||||
|
||||
from cassandra.cqlengine import columns
|
||||
from cassandra.cqlengine.exceptions import ModelException, CQLEngineException, ValidationError
|
||||
@@ -7,16 +8,20 @@ from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
|
||||
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
|
||||
from cassandra.util import OrderedDict
|
||||
|
||||
class ModelDefinitionException(ModelException): pass
|
||||
|
||||
class ModelDefinitionException(ModelException):
|
||||
pass
|
||||
|
||||
|
||||
class PolyMorphicModelException(ModelException): pass
|
||||
class PolyMorphicModelException(ModelException):
|
||||
pass
|
||||
|
||||
DEFAULT_KEYSPACE = None
|
||||
|
||||
class UndefinedKeyspaceWarning(Warning):
|
||||
pass
|
||||
|
||||
DEFAULT_KEYSPACE = None
|
||||
|
||||
|
||||
class hybrid_classmethod(object):
|
||||
"""
|
||||
@@ -106,7 +111,7 @@ class TTLDescriptor(object):
|
||||
"""
|
||||
def __get__(self, instance, model):
|
||||
if instance:
|
||||
#instance = copy.deepcopy(instance)
|
||||
# instance = copy.deepcopy(instance)
|
||||
# instance method
|
||||
def ttl_setter(ts):
|
||||
instance._ttl = ts
|
||||
@@ -124,6 +129,7 @@ class TTLDescriptor(object):
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TimestampDescriptor(object):
|
||||
"""
|
||||
returns a query set descriptor with a timestamp specified
|
||||
@@ -138,10 +144,10 @@ class TimestampDescriptor(object):
|
||||
|
||||
return model.objects.timestamp
|
||||
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IfNotExistsDescriptor(object):
|
||||
"""
|
||||
return a query set descriptor with a if_not_exists flag specified
|
||||
@@ -159,13 +165,14 @@ class IfNotExistsDescriptor(object):
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ConsistencyDescriptor(object):
|
||||
"""
|
||||
returns a query set descriptor if called on Class, instance if it was an instance call
|
||||
"""
|
||||
def __get__(self, instance, model):
|
||||
if instance:
|
||||
#instance = copy.deepcopy(instance)
|
||||
# instance = copy.deepcopy(instance)
|
||||
def consistency_setter(consistency):
|
||||
instance.__consistency__ = consistency
|
||||
return instance
|
||||
@@ -229,7 +236,7 @@ class ColumnDescriptor(object):
|
||||
"""
|
||||
try:
|
||||
return instance._values[self.column.column_name].getval()
|
||||
except AttributeError as e:
|
||||
except AttributeError:
|
||||
return self.query_evaluator
|
||||
|
||||
def __set__(self, instance, value):
|
||||
@@ -258,9 +265,11 @@ class BaseModel(object):
|
||||
The base model class, don't inherit from this, inherit from Model, defined below
|
||||
"""
|
||||
|
||||
class DoesNotExist(_DoesNotExist): pass
|
||||
class DoesNotExist(_DoesNotExist):
|
||||
pass
|
||||
|
||||
class MultipleObjectsReturned(_MultipleObjectsReturned): pass
|
||||
class MultipleObjectsReturned(_MultipleObjectsReturned):
|
||||
pass
|
||||
|
||||
objects = QuerySetDescriptor()
|
||||
ttl = TTLDescriptor()
|
||||
@@ -269,7 +278,7 @@ class BaseModel(object):
|
||||
|
||||
# custom timestamps, see USING TIMESTAMP X
|
||||
timestamp = TimestampDescriptor()
|
||||
|
||||
|
||||
if_not_exists = IfNotExistsDescriptor()
|
||||
|
||||
# _len is lazily created by __len__
|
||||
@@ -302,8 +311,7 @@ class BaseModel(object):
|
||||
__queryset__ = ModelQuerySet
|
||||
__dmlquery__ = DMLQuery
|
||||
|
||||
__consistency__ = None # can be set per query
|
||||
|
||||
__consistency__ = None # can be set per query
|
||||
|
||||
# Additional table properties
|
||||
__bloom_filter_fp_chance__ = None
|
||||
@@ -318,9 +326,9 @@ class BaseModel(object):
|
||||
__read_repair_chance__ = None
|
||||
__replicate_on_write__ = None
|
||||
|
||||
_timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP)
|
||||
_timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP)
|
||||
|
||||
_if_not_exists = False # optional if_not_exists flag to check existence before insertion
|
||||
_if_not_exists = False # optional if_not_exists flag to check existence before insertion
|
||||
|
||||
def __init__(self, **values):
|
||||
self._values = {}
|
||||
@@ -343,21 +351,19 @@ class BaseModel(object):
|
||||
self._batch = None
|
||||
self._timeout = NOT_SET
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Pretty printing of models by their primary key
|
||||
"""
|
||||
return '{} <{}>'.format(self.__class__.__name__,
|
||||
', '.join(('{}={}'.format(k, getattr(self, k)) for k,v in six.iteritems(self._primary_keys)))
|
||||
', '.join(('{}={}'.format(k, getattr(self, k)) for k, v in six.iteritems(self._primary_keys)))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def _discover_polymorphic_submodels(cls):
|
||||
if not cls._is_polymorphic_base:
|
||||
raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes')
|
||||
|
||||
def _discover(klass):
|
||||
if not klass._is_polymorphic_base and klass.__polymorphic_key__ is not None:
|
||||
cls._polymorphic_map[klass.__polymorphic_key__] = klass
|
||||
@@ -381,7 +387,7 @@ class BaseModel(object):
|
||||
# and translate that into our local fields
|
||||
# the db_map is a db_field -> model field map
|
||||
items = values.items()
|
||||
field_dict = dict([(cls._db_map.get(k, k),v) for k,v in items])
|
||||
field_dict = dict([(cls._db_map.get(k, k), v) for k, v in items])
|
||||
|
||||
if cls._is_polymorphic:
|
||||
poly_key = field_dict.get(cls._polymorphic_column_name)
|
||||
@@ -421,8 +427,9 @@ class BaseModel(object):
|
||||
|
||||
:return:
|
||||
"""
|
||||
if not self._is_persisted: return False
|
||||
pks = self._primary_keys.keys()
|
||||
if not self._is_persisted:
|
||||
return False
|
||||
|
||||
return all([not self._values[k].changed for k in self._primary_keys])
|
||||
|
||||
@classmethod
|
||||
@@ -481,11 +488,14 @@ class BaseModel(object):
|
||||
ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)
|
||||
|
||||
cf_name += ccase(cls.__name__)
|
||||
#trim to less than 48 characters or cassandra will complain
|
||||
# trim to less than 48 characters or cassandra will complain
|
||||
cf_name = cf_name[-48:]
|
||||
cf_name = cf_name.lower()
|
||||
cf_name = re.sub(r'^_+', '', cf_name)
|
||||
if not include_keyspace: return cf_name
|
||||
|
||||
if not include_keyspace:
|
||||
return cf_name
|
||||
|
||||
return '{}.{}'.format(cls._get_keyspace(), cf_name)
|
||||
|
||||
def validate(self):
|
||||
@@ -499,8 +509,7 @@ class BaseModel(object):
|
||||
val = col.validate(v)
|
||||
setattr(self, name, val)
|
||||
|
||||
### Let an instance be used like a dict of its columns keys/values
|
||||
|
||||
# Let an instance be used like a dict of its columns keys/values
|
||||
def __iter__(self):
|
||||
""" Iterate over column ids. """
|
||||
for column_id in self._columns.keys():
|
||||
@@ -620,7 +629,6 @@ class BaseModel(object):
|
||||
else:
|
||||
setattr(self, self._polymorphic_column_name, self.__polymorphic_key__)
|
||||
|
||||
is_new = self.pk is None
|
||||
self.validate()
|
||||
self.__dmlquery__(self.__class__, self,
|
||||
batch=self._batch,
|
||||
@@ -631,7 +639,7 @@ class BaseModel(object):
|
||||
transaction=self._transaction,
|
||||
timeout=self._timeout).save()
|
||||
|
||||
#reset the value managers
|
||||
# reset the value managers
|
||||
for v in self._values.values():
|
||||
v.reset_previous_value()
|
||||
self._is_persisted = True
|
||||
@@ -680,7 +688,7 @@ class BaseModel(object):
|
||||
transaction=self._transaction,
|
||||
timeout=self._timeout).update()
|
||||
|
||||
#reset the value managers
|
||||
# reset the value managers
|
||||
for v in self._values.values():
|
||||
v.reset_previous_value()
|
||||
self._is_persisted = True
|
||||
@@ -704,7 +712,7 @@ class BaseModel(object):
|
||||
"""
|
||||
Returns a list of the columns that have been updated since instantiation or save
|
||||
"""
|
||||
return [k for k,v in self._values.items() if v.changed]
|
||||
return [k for k, v in self._values.items() if v.changed]
|
||||
|
||||
@classmethod
|
||||
def _class_batch(cls, batch):
|
||||
@@ -715,29 +723,28 @@ class BaseModel(object):
|
||||
self._batch = batch
|
||||
return self
|
||||
|
||||
|
||||
batch = hybrid_classmethod(_class_batch, _inst_batch)
|
||||
|
||||
|
||||
class ModelMetaClass(type):
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
#move column definitions into columns dict
|
||||
#and set default column names
|
||||
# move column definitions into columns dict
|
||||
# and set default column names
|
||||
column_dict = OrderedDict()
|
||||
primary_keys = OrderedDict()
|
||||
pk_name = None
|
||||
|
||||
#get inherited properties
|
||||
# get inherited properties
|
||||
inherited_columns = OrderedDict()
|
||||
for base in bases:
|
||||
for k,v in getattr(base, '_defined_columns', {}).items():
|
||||
inherited_columns.setdefault(k,v)
|
||||
for k, v in getattr(base, '_defined_columns', {}).items():
|
||||
inherited_columns.setdefault(k, v)
|
||||
|
||||
#short circuit __abstract__ inheritance
|
||||
# short circuit __abstract__ inheritance
|
||||
is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False)
|
||||
|
||||
#short circuit __polymorphic_key__ inheritance
|
||||
# short circuit __polymorphic_key__ inheritance
|
||||
attrs['__polymorphic_key__'] = attrs.get('__polymorphic_key__', None)
|
||||
|
||||
def _transform_column(col_name, col_obj):
|
||||
@@ -745,11 +752,10 @@ class ModelMetaClass(type):
|
||||
if col_obj.primary_key:
|
||||
primary_keys[col_name] = col_obj
|
||||
col_obj.set_column_name(col_name)
|
||||
#set properties
|
||||
# set properties
|
||||
attrs[col_name] = ColumnDescriptor(col_obj)
|
||||
|
||||
column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.Column)]
|
||||
#column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position))
|
||||
column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)]
|
||||
column_definitions = sorted(column_definitions, key=lambda x: x[1].position)
|
||||
|
||||
is_polymorphic_base = any([c[1].polymorphic_key for c in column_definitions])
|
||||
@@ -780,7 +786,7 @@ class ModelMetaClass(type):
|
||||
defined_columns = OrderedDict(column_definitions)
|
||||
|
||||
# check for primary key
|
||||
if not is_abstract and not any([v.primary_key for k,v in column_definitions]):
|
||||
if not is_abstract and not any([v.primary_key for k, v in column_definitions]):
|
||||
raise ModelDefinitionException("At least 1 primary key is required.")
|
||||
|
||||
counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)]
|
||||
@@ -790,7 +796,7 @@ class ModelMetaClass(type):
|
||||
|
||||
has_partition_keys = any(v.partition_key for (k, v) in column_definitions)
|
||||
|
||||
#transform column definitions
|
||||
# transform column definitions
|
||||
for k, v in column_definitions:
|
||||
# don't allow a column with the same name as a built-in attribute or method
|
||||
if k in BaseModel.__dict__:
|
||||
@@ -810,7 +816,7 @@ class ModelMetaClass(type):
|
||||
partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key)
|
||||
clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key)
|
||||
|
||||
#setup partition key shortcut
|
||||
# setup partition key shortcut
|
||||
if len(partition_keys) == 0:
|
||||
if not is_abstract:
|
||||
raise ModelException("at least one partition key must be defined")
|
||||
@@ -835,12 +841,12 @@ class ModelMetaClass(type):
|
||||
raise ModelException("invalid clustering order {} for column {}".format(repr(v.clustering_order), v.db_field_name))
|
||||
col_names.add(v.db_field_name)
|
||||
|
||||
#create db_name -> model name map for loading
|
||||
# create db_name -> model name map for loading
|
||||
db_map = {}
|
||||
for field_name, col in column_dict.items():
|
||||
db_map[col.db_field_name] = field_name
|
||||
|
||||
#add management members to the class
|
||||
# add management members to the class
|
||||
attrs['_columns'] = column_dict
|
||||
attrs['_primary_keys'] = primary_keys
|
||||
attrs['_defined_columns'] = defined_columns
|
||||
@@ -862,29 +868,31 @@ class ModelMetaClass(type):
|
||||
attrs['_polymorphic_column_name'] = polymorphic_column_name
|
||||
attrs['_polymorphic_map'] = {} if is_polymorphic_base else None
|
||||
|
||||
#setup class exceptions
|
||||
# setup class exceptions
|
||||
DoesNotExistBase = None
|
||||
for base in bases:
|
||||
DoesNotExistBase = getattr(base, 'DoesNotExist', None)
|
||||
if DoesNotExistBase is not None: break
|
||||
if DoesNotExistBase is not None:
|
||||
break
|
||||
|
||||
DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist)
|
||||
attrs['DoesNotExist'] = type('DoesNotExist', (DoesNotExistBase,), {})
|
||||
|
||||
MultipleObjectsReturnedBase = None
|
||||
for base in bases:
|
||||
MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None)
|
||||
if MultipleObjectsReturnedBase is not None: break
|
||||
if MultipleObjectsReturnedBase is not None:
|
||||
break
|
||||
|
||||
MultipleObjectsReturnedBase = DoesNotExistBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned)
|
||||
attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {})
|
||||
|
||||
#create the class and add a QuerySet to it
|
||||
# create the class and add a QuerySet to it
|
||||
klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs)
|
||||
|
||||
return klass
|
||||
|
||||
|
||||
import six
|
||||
|
||||
@six.add_metaclass(ModelMetaClass)
|
||||
class Model(BaseModel):
|
||||
__abstract__ = True
|
||||
@@ -906,7 +914,7 @@ class Model(BaseModel):
|
||||
__default_ttl__ = None
|
||||
"""
|
||||
*Optional* The default ttl used by this model.
|
||||
|
||||
|
||||
This can be overridden by using the :meth:`~.ttl` method.
|
||||
"""
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from cassandra.cqlengine.query import AbstractQueryableColumn, SimpleQuerySet
|
||||
from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
|
||||
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
|
||||
|
||||
|
||||
class QuerySetDescriptor(object):
|
||||
"""
|
||||
returns a fresh queryset for the given model
|
||||
@@ -63,8 +64,11 @@ class NamedTable(object):
|
||||
|
||||
objects = QuerySetDescriptor()
|
||||
|
||||
class DoesNotExist(_DoesNotExist): pass
|
||||
class MultipleObjectsReturned(_MultipleObjectsReturned): pass
|
||||
class DoesNotExist(_DoesNotExist):
|
||||
pass
|
||||
|
||||
class MultipleObjectsReturned(_MultipleObjectsReturned):
|
||||
pass
|
||||
|
||||
def __init__(self, keyspace, name):
|
||||
self.keyspace = keyspace
|
||||
@@ -118,4 +122,3 @@ class NamedKeyspace(object):
|
||||
name that belongs to this keyspace
|
||||
"""
|
||||
return NamedTable(self.name, name)
|
||||
|
||||
|
||||
@@ -2,7 +2,9 @@ import six
|
||||
import sys
|
||||
|
||||
|
||||
class QueryOperatorException(Exception): pass
|
||||
class QueryOperatorException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# move to central spot
|
||||
class UnicodeMixin(object):
|
||||
@@ -31,12 +33,14 @@ class BaseQueryOperator(UnicodeMixin):
|
||||
raise QueryOperatorException("get_operator can only be called from a BaseQueryOperator subclass")
|
||||
if not hasattr(cls, 'opmap'):
|
||||
cls.opmap = {}
|
||||
|
||||
def _recurse(klass):
|
||||
if klass.symbol:
|
||||
cls.opmap[klass.symbol.upper()] = klass
|
||||
for subklass in klass.__subclasses__():
|
||||
_recurse(subklass)
|
||||
pass
|
||||
|
||||
_recurse(cls)
|
||||
try:
|
||||
return cls.opmap[symbol.upper()]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
import six
|
||||
|
||||
from cassandra.cqlengine import columns
|
||||
from cassandra.cqlengine.connection import execute, NOT_SET
|
||||
@@ -17,11 +18,16 @@ from cassandra.cqlengine.statements import (WhereClause, SelectStatement, Delete
|
||||
TransactionClause)
|
||||
|
||||
|
||||
class QueryException(CQLEngineException): pass
|
||||
class DoesNotExist(QueryException): pass
|
||||
class MultipleObjectsReturned(QueryException): pass
|
||||
class QueryException(CQLEngineException):
|
||||
pass
|
||||
|
||||
import six
|
||||
|
||||
class DoesNotExist(QueryException):
|
||||
pass
|
||||
|
||||
|
||||
class MultipleObjectsReturned(QueryException):
|
||||
pass
|
||||
|
||||
|
||||
def check_applied(result):
|
||||
@@ -30,7 +36,7 @@ def check_applied(result):
|
||||
if that value is false, it means our light-weight transaction didn't
|
||||
applied to database.
|
||||
"""
|
||||
if result and '[applied]' in result[0] and result[0]['[applied]'] == False:
|
||||
if result and '[applied]' in result[0] and not result[0]['[applied]']:
|
||||
raise LWTException('')
|
||||
|
||||
|
||||
@@ -77,8 +83,8 @@ class AbstractQueryableColumn(UnicodeMixin):
|
||||
|
||||
|
||||
class BatchType(object):
|
||||
Unlogged = 'UNLOGGED'
|
||||
Counter = 'COUNTER'
|
||||
Unlogged = 'UNLOGGED'
|
||||
Counter = 'COUNTER'
|
||||
|
||||
|
||||
class BatchQuery(object):
|
||||
@@ -194,8 +200,9 @@ class BatchQuery(object):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
#don't execute if there was an exception by default
|
||||
if exc_type is not None and not self._execute_on_exception: return
|
||||
# don't execute if there was an exception by default
|
||||
if exc_type is not None and not self._execute_on_exception:
|
||||
return
|
||||
self.execute()
|
||||
|
||||
|
||||
@@ -205,29 +212,29 @@ class AbstractQuerySet(object):
|
||||
super(AbstractQuerySet, self).__init__()
|
||||
self.model = model
|
||||
|
||||
#Where clause filters
|
||||
# Where clause filters
|
||||
self._where = []
|
||||
|
||||
# Transaction clause filters
|
||||
self._transaction = []
|
||||
|
||||
#ordering arguments
|
||||
# ordering arguments
|
||||
self._order = []
|
||||
|
||||
self._allow_filtering = False
|
||||
|
||||
#CQL has a default limit of 10000, it's defined here
|
||||
#because explicit is better than implicit
|
||||
# CQL has a default limit of 10000, it's defined here
|
||||
# because explicit is better than implicit
|
||||
self._limit = 10000
|
||||
|
||||
#see the defer and only methods
|
||||
# see the defer and only methods
|
||||
self._defer_fields = []
|
||||
self._only_fields = []
|
||||
|
||||
self._values_list = False
|
||||
self._flat_values_list = False
|
||||
|
||||
#results cache
|
||||
# results cache
|
||||
self._con = None
|
||||
self._cur = None
|
||||
self._result_cache = None
|
||||
@@ -265,7 +272,7 @@ class AbstractQuerySet(object):
|
||||
def __deepcopy__(self, memo):
|
||||
clone = self.__class__(self.model)
|
||||
for k, v in self.__dict__.items():
|
||||
if k in ['_con', '_cur', '_result_cache', '_result_idx']: # don't clone these
|
||||
if k in ['_con', '_cur', '_result_cache', '_result_idx']: # don't clone these
|
||||
clone.__dict__[k] = None
|
||||
elif k == '_batch':
|
||||
# we need to keep the same batch instance across
|
||||
@@ -284,7 +291,7 @@ class AbstractQuerySet(object):
|
||||
self._execute_query()
|
||||
return len(self._result_cache)
|
||||
|
||||
#----query generation / execution----
|
||||
# ----query generation / execution----
|
||||
|
||||
def _select_fields(self):
|
||||
""" returns the fields to select """
|
||||
@@ -308,7 +315,7 @@ class AbstractQuerySet(object):
|
||||
allow_filtering=self._allow_filtering
|
||||
)
|
||||
|
||||
#----Reads------
|
||||
# ----Reads------
|
||||
|
||||
def _execute_query(self):
|
||||
if self._batch:
|
||||
@@ -330,7 +337,7 @@ class AbstractQuerySet(object):
|
||||
self._result_idx += 1
|
||||
self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx])
|
||||
|
||||
#return the connection to the connection pool if we have all objects
|
||||
# return the connection to the connection pool if we have all objects
|
||||
if self._result_cache and self._result_idx == (len(self._result_cache) - 1):
|
||||
self._con = None
|
||||
self._cur = None
|
||||
@@ -350,7 +357,7 @@ class AbstractQuerySet(object):
|
||||
num_results = len(self._result_cache)
|
||||
|
||||
if isinstance(s, slice):
|
||||
#calculate the amount of results that need to be loaded
|
||||
# calculate the amount of results that need to be loaded
|
||||
end = num_results if s.step is None else s.step
|
||||
if end < 0:
|
||||
end += num_results
|
||||
@@ -359,11 +366,12 @@ class AbstractQuerySet(object):
|
||||
self._fill_result_cache_to_idx(end)
|
||||
return self._result_cache[s.start:s.stop:s.step]
|
||||
else:
|
||||
#return the object at this index
|
||||
# return the object at this index
|
||||
s = int(s)
|
||||
|
||||
#handle negative indexing
|
||||
if s < 0: s += num_results
|
||||
# handle negative indexing
|
||||
if s < 0:
|
||||
s += num_results
|
||||
|
||||
if s >= num_results:
|
||||
raise IndexError
|
||||
@@ -453,7 +461,6 @@ class AbstractQuerySet(object):
|
||||
if not isinstance(val, Token):
|
||||
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
|
||||
column = columns._PartitionKeysToken(self.model)
|
||||
quote_field = False
|
||||
else:
|
||||
raise QueryException("Can't resolve column name: '{}'".format(col_name))
|
||||
|
||||
@@ -484,7 +491,7 @@ class AbstractQuerySet(object):
|
||||
|
||||
Returns a QuerySet filtered on the keyword arguments
|
||||
"""
|
||||
#add arguments to the where clause filters
|
||||
# add arguments to the where clause filters
|
||||
if len([x for x in kwargs.values() if x is None]):
|
||||
raise CQLEngineException("None values on filter are not allowed")
|
||||
|
||||
@@ -497,7 +504,7 @@ class AbstractQuerySet(object):
|
||||
for arg, val in kwargs.items():
|
||||
col_name, col_op = self._parse_filter_arg(arg)
|
||||
quote_field = True
|
||||
#resolve column and operator
|
||||
# resolve column and operator
|
||||
try:
|
||||
column = self.model._get_column(col_name)
|
||||
except KeyError:
|
||||
@@ -519,7 +526,7 @@ class AbstractQuerySet(object):
|
||||
len(val.value), len(partition_columns)))
|
||||
val.set_columns(partition_columns)
|
||||
|
||||
#get query operator, or use equals if not supplied
|
||||
# get query operator, or use equals if not supplied
|
||||
operator_class = BaseWhereOperator.get_operator(col_op or 'EQ')
|
||||
operator = operator_class()
|
||||
|
||||
@@ -559,8 +566,7 @@ class AbstractQuerySet(object):
|
||||
if len(self._result_cache) == 0:
|
||||
raise self.model.DoesNotExist
|
||||
elif len(self._result_cache) > 1:
|
||||
raise self.model.MultipleObjectsReturned(
|
||||
'{} objects found'.format(len(self._result_cache)))
|
||||
raise self.model.MultipleObjectsReturned('{} objects found'.format(len(self._result_cache)))
|
||||
else:
|
||||
return self[0]
|
||||
|
||||
@@ -665,7 +671,7 @@ class AbstractQuerySet(object):
|
||||
if clone._defer_fields or clone._only_fields:
|
||||
raise QueryException("QuerySet alread has only or defer fields defined")
|
||||
|
||||
#check for strange fields
|
||||
# check for strange fields
|
||||
missing_fields = [f for f in fields if f not in self.model._columns.keys()]
|
||||
if missing_fields:
|
||||
raise QueryException(
|
||||
@@ -698,7 +704,7 @@ class AbstractQuerySet(object):
|
||||
"""
|
||||
Deletes the contents of a query
|
||||
"""
|
||||
#validate where clause
|
||||
# validate where clause
|
||||
partition_key = [x for x in self.model._primary_keys.values()][0]
|
||||
if not any([c.field == partition_key.column_name for c in self._where]):
|
||||
raise QueryException("The partition key must be defined on delete queries")
|
||||
@@ -759,14 +765,14 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
"""
|
||||
def _validate_select_where(self):
|
||||
""" Checks that a filterset will not create invalid select statement """
|
||||
#check that there's either a = or IN relationship with a primary key or indexed field
|
||||
# check that there's either a = or IN relationship with a primary key or indexed field
|
||||
equal_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)]
|
||||
token_comparison = any([w for w in self._where if isinstance(w.value, Token)])
|
||||
if not any([w.primary_key or w.index for w in equal_ops]) and not token_comparison and not self._allow_filtering:
|
||||
raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field')
|
||||
|
||||
if not self._allow_filtering:
|
||||
#if the query is not on an indexed field
|
||||
# if the query is not on an indexed field
|
||||
if not any([w.index for w in equal_ops]):
|
||||
if not any([w.partition_key for w in equal_ops]) and not token_comparison:
|
||||
raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset')
|
||||
@@ -783,9 +789,9 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
|
||||
def _get_result_constructor(self):
|
||||
""" Returns a function that will be used to instantiate query results """
|
||||
if not self._values_list: # we want models
|
||||
if not self._values_list: # we want models
|
||||
return lambda rows: self.model._construct_instance(rows)
|
||||
elif self._flat_values_list: # the user has requested flattened list (1 value per row)
|
||||
elif self._flat_values_list: # the user has requested flattened list (1 value per row)
|
||||
return lambda row: row.popitem()[1]
|
||||
else:
|
||||
return lambda row: self._get_row_value_list(self._only_fields, row)
|
||||
@@ -803,7 +809,7 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
if column is None:
|
||||
raise QueryException("Can't resolve the column name: '{}'".format(colname))
|
||||
|
||||
#validate the column selection
|
||||
# validate the column selection
|
||||
if not column.primary_key:
|
||||
raise QueryException(
|
||||
"Can't order on '{}', can only order on (clustered) primary keys".format(colname))
|
||||
@@ -958,7 +964,7 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
|
||||
# we should not provide default values in this use case.
|
||||
val = col.validate(val)
|
||||
|
||||
|
||||
if val is None:
|
||||
nulled_columns.add(col_name)
|
||||
continue
|
||||
@@ -1072,11 +1078,11 @@ class DMLQuery(object):
|
||||
timestamp=self._timestamp, transactions=self._transaction)
|
||||
for name, col in self.instance._clustering_keys.items():
|
||||
null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None))
|
||||
#get defined fields and their column names
|
||||
# get defined fields and their column names
|
||||
for name, col in self.model._columns.items():
|
||||
# if clustering key is null, don't include non static columns
|
||||
if null_clustering_key and not col.static and not col.partition_key:
|
||||
continue
|
||||
continue
|
||||
if not col.is_primary_key:
|
||||
val = getattr(self.instance, name, None)
|
||||
val_mgr = self.instance._values[name]
|
||||
@@ -1088,19 +1094,24 @@ class DMLQuery(object):
|
||||
# don't update something if it hasn't changed
|
||||
if not val_mgr.changed and not isinstance(col, columns.Counter):
|
||||
continue
|
||||
|
||||
|
||||
static_changed_only = static_changed_only and col.static
|
||||
if isinstance(col, (columns.BaseContainerColumn, columns.Counter)):
|
||||
# get appropriate clause
|
||||
if isinstance(col, columns.List): klass = ListUpdateClause
|
||||
elif isinstance(col, columns.Map): klass = MapUpdateClause
|
||||
elif isinstance(col, columns.Set): klass = SetUpdateClause
|
||||
elif isinstance(col, columns.Counter): klass = CounterUpdateClause
|
||||
else: raise RuntimeError
|
||||
if isinstance(col, columns.List):
|
||||
klass = ListUpdateClause
|
||||
elif isinstance(col, columns.Map):
|
||||
klass = MapUpdateClause
|
||||
elif isinstance(col, columns.Set):
|
||||
klass = SetUpdateClause
|
||||
elif isinstance(col, columns.Counter):
|
||||
klass = CounterUpdateClause
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
# do the stuff
|
||||
clause = klass(col.db_field_name, val,
|
||||
previous=val_mgr.previous_value, column=col)
|
||||
previous=val_mgr.previous_value, column=col)
|
||||
if clause.get_context_size() > 0:
|
||||
statement.add_assignment_clause(clause)
|
||||
else:
|
||||
@@ -1145,7 +1156,7 @@ class DMLQuery(object):
|
||||
static_save_only = static_save_only and col._val_is_null(getattr(self.instance, name, None))
|
||||
for name, col in self.instance._columns.items():
|
||||
if static_save_only and not col.static and not col.partition_key:
|
||||
continue
|
||||
continue
|
||||
val = getattr(self.instance, name, None)
|
||||
if col._val_is_null(val):
|
||||
if self.instance._values[name].changed:
|
||||
@@ -1171,7 +1182,9 @@ class DMLQuery(object):
|
||||
|
||||
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp)
|
||||
for name, col in self.model._primary_keys.items():
|
||||
if (not col.partition_key) and (getattr(self.instance, name) is None): continue
|
||||
if (not col.partition_key) and (getattr(self.instance, name) is None):
|
||||
continue
|
||||
|
||||
ds.add_where_clause(WhereClause(
|
||||
col.db_field_name,
|
||||
EqualsOperator(),
|
||||
|
||||
@@ -7,7 +7,8 @@ from cassandra.cqlengine.functions import QueryValue
|
||||
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator
|
||||
|
||||
|
||||
class StatementException(Exception): pass
|
||||
class StatementException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnicodeMixin(object):
|
||||
@@ -16,6 +17,7 @@ class UnicodeMixin(object):
|
||||
else:
|
||||
__str__ = lambda x: six.text_type(x).encode('utf-8')
|
||||
|
||||
|
||||
class ValueQuoter(UnicodeMixin):
|
||||
|
||||
def __init__(self, value):
|
||||
@@ -28,7 +30,7 @@ class ValueQuoter(UnicodeMixin):
|
||||
elif isinstance(self.value, (list, tuple)):
|
||||
return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']'
|
||||
elif isinstance(self.value, dict):
|
||||
return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k,v in self.value.items()]) + '}'
|
||||
return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}'
|
||||
elif isinstance(self.value, set):
|
||||
return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}'
|
||||
return cql_quote(self.value)
|
||||
@@ -215,7 +217,8 @@ class SetUpdateClause(ContainerUpdateClause):
|
||||
self._analyzed = True
|
||||
|
||||
def get_context_size(self):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
if (self.previous is None and
|
||||
not self._assignments and
|
||||
self._additions is None and
|
||||
@@ -224,7 +227,8 @@ class SetUpdateClause(ContainerUpdateClause):
|
||||
return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals))
|
||||
|
||||
def update_context(self, ctx):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
ctx_id = self.context_id
|
||||
if (self.previous is None and
|
||||
self._assignments is None and
|
||||
@@ -250,7 +254,8 @@ class ListUpdateClause(ContainerUpdateClause):
|
||||
self._prepend = None
|
||||
|
||||
def __unicode__(self):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
qs = []
|
||||
ctx_id = self.context_id
|
||||
if self._assignments is not None:
|
||||
@@ -267,11 +272,13 @@ class ListUpdateClause(ContainerUpdateClause):
|
||||
return ', '.join(qs)
|
||||
|
||||
def get_context_size(self):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend))
|
||||
|
||||
def update_context(self, ctx):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
ctx_id = self.context_id
|
||||
if self._assignments is not None:
|
||||
ctx[str(ctx_id)] = self._to_database(self._assignments)
|
||||
@@ -313,13 +320,13 @@ class ListUpdateClause(ContainerUpdateClause):
|
||||
else:
|
||||
|
||||
# the max start idx we want to compare
|
||||
search_space = len(self.value) - max(0, len(self.previous)-1)
|
||||
search_space = len(self.value) - max(0, len(self.previous) - 1)
|
||||
|
||||
# the size of the sub lists we want to look at
|
||||
search_size = len(self.previous)
|
||||
|
||||
for i in range(search_space):
|
||||
#slice boundary
|
||||
# slice boundary
|
||||
j = i + search_size
|
||||
sub = self.value[i:j]
|
||||
idx_cmp = lambda idx: self.previous[idx] == sub[idx]
|
||||
@@ -354,13 +361,15 @@ class MapUpdateClause(ContainerUpdateClause):
|
||||
self._analyzed = True
|
||||
|
||||
def get_context_size(self):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
if self.previous is None and not self._updates:
|
||||
return 1
|
||||
return len(self._updates or []) * 2
|
||||
|
||||
def update_context(self, ctx):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
ctx_id = self.context_id
|
||||
if self.previous is None and not self._updates:
|
||||
ctx[str(ctx_id)] = {}
|
||||
@@ -372,7 +381,8 @@ class MapUpdateClause(ContainerUpdateClause):
|
||||
ctx_id += 2
|
||||
|
||||
def __unicode__(self):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
qs = []
|
||||
|
||||
ctx_id = self.context_id
|
||||
@@ -439,16 +449,19 @@ class MapDeleteClause(BaseDeleteClause):
|
||||
self._analyzed = True
|
||||
|
||||
def update_context(self, ctx):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
for idx, key in enumerate(self._removals):
|
||||
ctx[str(self.context_id + idx)] = key
|
||||
|
||||
def get_context_size(self):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
return len(self._removals)
|
||||
|
||||
def __unicode__(self):
|
||||
if not self._analyzed: self._analyze()
|
||||
if not self._analyzed:
|
||||
self._analyze()
|
||||
return ', '.join(['"{}"[%({})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))])
|
||||
|
||||
|
||||
@@ -521,7 +534,6 @@ class BaseCQLStatement(UnicodeMixin):
|
||||
def __unicode__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return self.__unicode__()
|
||||
|
||||
@@ -645,13 +657,12 @@ class InsertStatement(AssignmentStatement):
|
||||
ttl=None,
|
||||
timestamp=None,
|
||||
if_not_exists=False):
|
||||
super(InsertStatement, self).__init__(
|
||||
table,
|
||||
assignments=assignments,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
ttl=ttl,
|
||||
timestamp=timestamp)
|
||||
super(InsertStatement, self).__init__(table,
|
||||
assignments=assignments,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
ttl=ttl,
|
||||
timestamp=timestamp)
|
||||
|
||||
self.if_not_exists = if_not_exists
|
||||
|
||||
@@ -813,4 +824,3 @@ class DeleteStatement(BaseCQLStatement):
|
||||
qs += [self._where]
|
||||
|
||||
return ' '.join(qs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user