PEP-8 cleanup

Adam Holmberg
2015-02-10 11:47:22 -06:00
parent f078e23935
commit 62c25d60e9
10 changed files with 282 additions and 201 deletions
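
The hunks below are mechanical style fixes. The recurring patterns, condensed into one illustrative before/after sketch (names are placeholders, not from the diff):

    # Before: the styles this commit removes
    #comment without a space
    def f(value):
        if value is None: return                  # compound statement
        return {k:v for k,v in value.items()}     # no space after , and :

    # After: the PEP-8 forms it applies
    # comment with a space
    def f(value):
        if value is None:
            return
        return {k: v for k, v in value.items()}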

View File

@@ -1,4 +1,3 @@
#column field types
from copy import deepcopy, copy
from datetime import datetime
from datetime import date
@@ -86,7 +85,7 @@ class ValueQuoter(object):
class Column(object):
#the cassandra type this column maps to
# the cassandra type this column maps to
db_type = None
value_manager = BaseValueManager
@@ -161,13 +160,13 @@ class Column(object):
self.required = required
self.clustering_order = clustering_order
self.polymorphic_key = polymorphic_key
#the column name in the model definition
# the column name in the model definition
self.column_name = None
self.static = static
self.value = None
#keep track of instantiation order
# keep track of instantiation order
self.position = Column.instance_counter
Column.instance_counter += 1
@@ -266,17 +265,18 @@ class Blob(Column):
return bytearray(val)
def to_python(self, value):
#return value[2:].decode('hex')
return value
Bytes = Blob
class Ascii(Column):
"""
Stores a US-ASCII character string
"""
db_type = 'ascii'
class Inet(Column):
"""
Stores an IP address in IPv4 or IPv6 format
@@ -309,7 +309,8 @@ class Text(Column):
def validate(self, value):
value = super(Text, self).validate(value)
if value is None: return
if value is None:
return
if not isinstance(value, (six.string_types, bytearray)) and value is not None:
raise ValidationError('{} {} is not a string'.format(self.column_name, type(value)))
if self.max_length:
@@ -330,7 +331,8 @@ class Integer(Column):
def validate(self, value):
val = super(Integer, self).validate(value)
if val is None: return
if val is None:
return
try:
return int(val)
except (TypeError, ValueError):
@@ -409,7 +411,8 @@ class DateTime(Column):
db_type = 'timestamp'
def to_python(self, value):
if value is None: return
if value is None:
return
if isinstance(value, datetime):
return value
elif isinstance(value, date):
@@ -421,7 +424,8 @@ class DateTime(Column):
def to_database(self, value):
value = super(DateTime, self).to_database(value)
if value is None: return
if value is None:
return
if not isinstance(value, datetime):
if isinstance(value, date):
value = datetime(value.year, value.month, value.day)
@@ -443,7 +447,8 @@ class Date(Column):
db_type = 'timestamp'
def to_python(self, value):
if value is None: return
if value is None:
return
if isinstance(value, datetime):
return value.date()
elif isinstance(value, date):
@@ -455,7 +460,8 @@ class Date(Column):
def to_database(self, value):
value = super(Date, self).to_database(value)
if value is None: return
if value is None:
return
if isinstance(value, datetime):
value = value.date()
if not isinstance(value, date):
@@ -474,9 +480,11 @@ class UUID(Column):
def validate(self, value):
val = super(UUID, self).validate(value)
if val is None: return
if val is None:
return
from uuid import UUID as _UUID
if isinstance(val, _UUID): return val
if isinstance(val, _UUID):
return val
if isinstance(val, six.string_types) and self.re_uuid.match(val):
return _UUID(val)
raise ValidationError("{} {} is not a valid uuid".format(self.column_name, value))
@@ -510,7 +518,7 @@ class TimeUUID(UUID):
epoch = datetime(1970, 1, 1, tzinfo=dt.tzinfo)
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
timestamp = (dt - epoch).total_seconds() - offset
node = None
clock_seq = None
@@ -529,7 +537,7 @@ class TimeUUID(UUID):
if node is None:
node = getnode()
return pyUUID(fields=(time_low, time_mid, time_hi_version,
clock_seq_hi_variant, clock_seq_low, node), version=1)
class Boolean(Column):
@@ -563,7 +571,8 @@ class Float(Column):
def validate(self, value):
value = super(Float, self).validate(value)
if value is None: return
if value is None:
return
try:
return float(value)
except (TypeError, ValueError):
@@ -586,7 +595,8 @@ class Decimal(Column):
from decimal import Decimal as _Decimal
from decimal import InvalidOperation
val = super(Decimal, self).validate(value)
if val is None: return
if val is None:
return
try:
return _Decimal(val)
except InvalidOperation:
@@ -679,7 +689,8 @@ class Set(BaseContainerColumn):
def validate(self, value):
val = super(Set, self).validate(value)
if val is None: return
if val is None:
return
types = (set,) if self.strict else (set, list, tuple)
if not isinstance(val, types):
if self.strict:
@@ -693,18 +704,19 @@ class Set(BaseContainerColumn):
return {self.value_col.validate(v) for v in val}
def to_python(self, value):
if value is None: return set()
if value is None:
return set()
return {self.value_col.to_python(v) for v in value}
def to_database(self, value):
if value is None: return None
if value is None:
return None
if isinstance(value, self.Quoter): return value
if isinstance(value, self.Quoter):
return value
return self.Quoter({self.value_col.to_database(v) for v in value})
class List(BaseContainerColumn):
"""
Stores a list of ordered values
@@ -727,7 +739,8 @@ class List(BaseContainerColumn):
def validate(self, value):
val = super(List, self).validate(value)
if val is None: return
if val is None:
return
if not isinstance(val, (set, list, tuple)):
raise ValidationError('{} {} is not a list object'.format(self.column_name, val))
if None in val:
@@ -735,12 +748,15 @@ class List(BaseContainerColumn):
return [self.value_col.validate(v) for v in val]
def to_python(self, value):
if value is None: return []
if value is None:
return []
return [self.value_col.to_python(v) for v in value]
def to_database(self, value):
if value is None: return None
if isinstance(value, self.Quoter): return value
if value is None:
return None
if isinstance(value, self.Quoter):
return value
return self.Quoter([self.value_col.to_database(v) for v in value])
@@ -757,7 +773,7 @@ class Map(BaseContainerColumn):
def __str__(self):
cq = cql_quote
return '{' + ', '.join([cq(k) + ':' + cq(v) for k,v in self.value.items()]) + '}'
return '{' + ', '.join([cq(k) + ':' + cq(v) for k, v in self.value.items()]) + '}'
def get(self, key):
return self.value.get(key)
@@ -802,22 +818,24 @@ class Map(BaseContainerColumn):
def validate(self, value):
val = super(Map, self).validate(value)
if val is None: return
if val is None:
return
if not isinstance(val, dict):
raise ValidationError('{} {} is not a dict object'.format(self.column_name, val))
return {self.key_col.validate(k):self.value_col.validate(v) for k,v in val.items()}
return {self.key_col.validate(k): self.value_col.validate(v) for k, v in val.items()}
def to_python(self, value):
if value is None:
return {}
if value is not None:
return {self.key_col.to_python(k): self.value_col.to_python(v) for k,v in value.items()}
return {self.key_col.to_python(k): self.value_col.to_python(v) for k, v in value.items()}
def to_database(self, value):
if value is None: return None
if isinstance(value, self.Quoter): return value
return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()})
if value is None:
return None
if isinstance(value, self.Quoter):
return value
return self.Quoter({self.key_col.to_database(k): self.value_col.to_database(v) for k, v in value.items()})
class _PartitionKeysToken(Column):
@@ -842,4 +860,3 @@ class _PartitionKeysToken(Column):
def get_cql(self):
return "token({})".format(", ".join(c.cql for c in self.partition_columns))

View File

@@ -1,6 +1,6 @@
from collections import namedtuple
import six
import logging
import six
from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster, _NOT_SET, NoHostAvailable
@@ -9,11 +9,13 @@ from cassandra.cqlengine.exceptions import CQLEngineException, UndefinedKeyspace
from cassandra.cqlengine.statements import BaseCQLStatement
LOG = logging.getLogger('cqlengine.cql')
log = logging.getLogger(__name__)
NOT_SET = _NOT_SET # required for passing timeout to Session.execute
class CQLConnectionError(CQLEngineException): pass
class CQLConnectionError(CQLEngineException):
pass
Host = namedtuple('Host', ['name', 'port'])
@@ -22,6 +24,7 @@ session = None
lazy_connect_args = None
default_consistency_level = None
def setup(
hosts,
default_keyspace,
@@ -94,21 +97,24 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET):
elif isinstance(query, six.string_types):
query = SimpleStatement(query, consistency_level=consistency_level)
LOG.info(query.query_string)
log.debug(query.query_string)
params = params or {}
result = session.execute(query, params, timeout=timeout)
return result
def get_session():
handle_lazy_connect()
return session
def get_cluster():
handle_lazy_connect()
return cluster
def handle_lazy_connect():
global lazy_connect_args
if lazy_connect_args:
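
One behavioral note from the hunks above: query strings move from INFO on the 'cqlengine.cql' logger to DEBUG on a module-named logger. A sketch of how a caller might surface them again; the logger name assumes this module's import path:

    import logging

    logging.basicConfig(level=logging.DEBUG)
    # log = logging.getLogger(__name__) above makes this the logger to target
    logging.getLogger('cassandra.cqlengine.connection').setLevel(logging.DEBUG)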

View File

@@ -1,8 +1,22 @@
#cqlengine exceptions
class CQLEngineException(Exception): pass
class ModelException(CQLEngineException): pass
class ValidationError(CQLEngineException): pass
class CQLEngineException(Exception):
pass
class UndefinedKeyspaceException(CQLEngineException): pass
class LWTException(CQLEngineException): pass
class IfNotExistsWithCounterColumn(CQLEngineException): pass
class ModelException(CQLEngineException):
pass
class ValidationError(CQLEngineException):
pass
class UndefinedKeyspaceException(CQLEngineException):
pass
class LWTException(CQLEngineException):
pass
class IfNotExistsWithCounterColumn(CQLEngineException):
pass
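
Every class above still derives from CQLEngineException, so one handler can catch the whole family; a minimal sketch assuming this tree's import path:

    from cassandra.cqlengine.exceptions import CQLEngineException, ValidationError

    try:
        raise ValidationError('rejected value')
    except CQLEngineException as exc:
        print('{}: {}'.format(type(exc).__name__, exc))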

View File

@@ -5,12 +5,14 @@ import sys
from cassandra.cqlengine.exceptions import ValidationError
# move to central spot
class UnicodeMixin(object):
if sys.version_info > (3, 0):
__str__ = lambda x: x.__unicode__()
else:
__str__ = lambda x: six.text_type(x).encode('utf-8')
class QueryValue(UnicodeMixin):
"""
Base class for query filter values. Subclasses of these classes can
@@ -42,6 +44,8 @@ class BaseQueryFunction(QueryValue):
be passed into .filter() and will be translated into CQL functions in
the resulting query
"""
pass
class MinTimeUUID(BaseQueryFunction):
"""
@@ -123,4 +127,3 @@ class Token(BaseQueryFunction):
def update_context(self, ctx):
for i, (col, val) in enumerate(zip(self._columns, self.value)):
ctx[str(self.context_id + i)] = col.to_database(val)
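
The UnicodeMixin touched above centralizes py2/py3 string behavior: subclasses define __unicode__ and the mixin supplies a matching __str__. A self-contained sketch of the pattern; Greeting is illustrative:

    import sys
    import six

    class UnicodeMixin(object):
        if sys.version_info > (3, 0):
            __str__ = lambda x: x.__unicode__()
        else:
            __str__ = lambda x: six.text_type(x).encode('utf-8')

    class Greeting(UnicodeMixin):
        def __unicode__(self):
            return u'hello'

    print(str(Greeting()))  # text on py3, utf-8 bytes on py2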

View File

@@ -12,11 +12,12 @@ from cassandra.cqlengine.named import NamedTable
Field = namedtuple('Field', ['name', 'type'])
logger = logging.getLogger(__name__)
log = logging.getLogger(__name__)
# system keyspaces
schema_columnfamilies = NamedTable('system', 'schema_columnfamilies')
def create_keyspace(name, strategy_class, replication_factor, durable_writes=True, **replication_values):
"""
*Deprecated - this will likely be replaced with something specialized per replication strategy.*
@@ -38,10 +39,10 @@ def create_keyspace(name, strategy_class, replication_factor, durable_writes=Tru
cluster = get_cluster()
if name not in cluster.metadata.keyspaces:
#try the 1.2 method
# try the 1.2 method
replication_map = {
'class': strategy_class,
'replication_factor':replication_factor
'replication_factor': replication_factor
}
replication_map.update(replication_values)
if strategy_class.lower() != 'simplestrategy':
@@ -76,6 +77,7 @@ def delete_keyspace(name):
if name in cluster.metadata.keyspaces:
execute("DROP KEYSPACE {}".format(name))
def sync_table(model):
"""
Inspects the model and creates / updates the corresponding table and columns.
@@ -95,8 +97,7 @@ def sync_table(model):
if model.__abstract__:
raise CQLEngineException("cannot create table from abstract model")
#construct query string
# construct query string
cf_name = model.column_family_name()
raw_cf_name = model.column_family_name(include_keyspace=False)
@@ -107,7 +108,7 @@ def sync_table(model):
keyspace = cluster.metadata.keyspaces[ks_name]
tables = keyspace.tables
#check for an existing column family
# check for an existing column family
if raw_cf_name not in tables:
qs = get_create_table(model)
@@ -123,20 +124,21 @@ def sync_table(model):
fields = get_fields(model)
field_names = [x.name for x in fields]
for name, col in model._columns.items():
if col.primary_key or col.partition_key: continue # we can't mess with the PK
if col.db_field_name in field_names: continue # skip columns already defined
if col.primary_key or col.partition_key:
continue # we can't mess with the PK
if col.db_field_name in field_names:
continue # skip columns already defined
# add missing column using the column def
query = "ALTER TABLE {} add {}".format(cf_name, col.get_column_def())
logger.debug(query)
log.debug(query)
execute(query)
update_compaction(model)
table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name]
indexes = [c for n,c in model._columns.items() if c.index]
indexes = [c for n, c in model._columns.items() if c.index]
for column in indexes:
if table.columns[column.db_field_name].index:
@@ -148,20 +150,23 @@ def sync_table(model):
qs = ' '.join(qs)
execute(qs)
def get_create_table(model):
cf_name = model.column_family_name()
qs = ['CREATE TABLE {}'.format(cf_name)]
#add column types
# add column types
pkeys = [] # primary keys
ckeys = [] # clustering keys
qtypes = [] # field types
def add_column(col):
s = col.get_column_def()
if col.primary_key:
keys = (pkeys if col.partition_key else ckeys)
keys.append('"{}"'.format(col.db_field_name))
qtypes.append(s)
for name, col in model._columns.items():
add_column(col)
@@ -172,9 +177,9 @@ def get_create_table(model):
with_qs = []
table_properties = ['bloom_filter_fp_chance', 'caching', 'comment',
'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds',
'index_interval', 'memtable_flush_period_in_ms', 'populate_io_cache_on_flush',
'read_repair_chance', 'replicate_on_write']
for prop_name in table_properties:
prop_value = getattr(model, '__{}__'.format(prop_name), None)
if prop_value is not None:
@@ -211,9 +216,9 @@ def get_compaction_options(model):
if not model.__compaction__:
return {}
result = {'class':model.__compaction__}
result = {'class': model.__compaction__}
def setter(key, limited_to_strategy = None):
def setter(key, limited_to_strategy=None):
"""
sets key in result, checking if the key is limited to either SizeTiered or Leveled
:param key: one of the compaction options, like "bucket_high"
@@ -270,6 +275,7 @@ def get_table_settings(model):
table = cluster.metadata.keyspaces[ks].tables[table]
return table
def update_compaction(model):
"""Updates the compaction options for the given model if necessary.
@@ -279,7 +285,7 @@ def update_compaction(model):
`False` otherwise.
:rtype: bool
"""
logger.debug("Checking %s for compaction differences", model)
log.debug("Checking %s for compaction differences", model)
table = get_table_settings(model)
existing_options = table.options.copy()
@@ -311,7 +317,7 @@ def update_compaction(model):
options = json.dumps(options).replace('"', "'")
cf_name = model.column_family_name()
query = "ALTER TABLE {} with compaction = {}".format(cf_name, options)
logger.debug(query)
log.debug(query)
execute(query)
return True
@@ -339,6 +345,3 @@ def drop_table(model):
execute('drop table {};'.format(model.column_family_name(include_keyspace=True)))
except KeyError:
pass
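
A typical call sequence for the helpers reworked above, assuming the setup() signature from the connection hunks; host, keyspace, and model are illustrative, and a reachable Cassandra node is required:

    from cassandra.cqlengine import columns, connection
    from cassandra.cqlengine.management import create_keyspace, sync_table
    from cassandra.cqlengine.models import Model

    class User(Model):
        user_id = columns.UUID(primary_key=True)
        name = columns.Text()

    connection.setup(['127.0.0.1'], default_keyspace='example')
    create_keyspace('example', 'SimpleStrategy', 1)  # deprecated form, per its docstring
    sync_table(User)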

View File

@@ -1,4 +1,5 @@
import re
import six
from cassandra.cqlengine import columns
from cassandra.cqlengine.exceptions import ModelException, CQLEngineException, ValidationError
@@ -7,16 +8,20 @@ from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
from cassandra.util import OrderedDict
class ModelDefinitionException(ModelException): pass
class ModelDefinitionException(ModelException):
pass
class PolyMorphicModelException(ModelException): pass
class PolyMorphicModelException(ModelException):
pass
DEFAULT_KEYSPACE = None
class UndefinedKeyspaceWarning(Warning):
pass
DEFAULT_KEYSPACE = None
class hybrid_classmethod(object):
"""
@@ -106,7 +111,7 @@ class TTLDescriptor(object):
"""
def __get__(self, instance, model):
if instance:
#instance = copy.deepcopy(instance)
# instance = copy.deepcopy(instance)
# instance method
def ttl_setter(ts):
instance._ttl = ts
@@ -124,6 +129,7 @@ class TTLDescriptor(object):
def __call__(self, *args, **kwargs):
raise NotImplementedError
class TimestampDescriptor(object):
"""
returns a query set descriptor with a timestamp specified
@@ -138,10 +144,10 @@ class TimestampDescriptor(object):
return model.objects.timestamp
def __call__(self, *args, **kwargs):
raise NotImplementedError
class IfNotExistsDescriptor(object):
"""
return a query set descriptor with an if_not_exists flag specified
@@ -159,13 +165,14 @@ class IfNotExistsDescriptor(object):
def __call__(self, *args, **kwargs):
raise NotImplementedError
class ConsistencyDescriptor(object):
"""
returns a query set descriptor if called on the class, or the instance if called on an instance
"""
def __get__(self, instance, model):
if instance:
#instance = copy.deepcopy(instance)
# instance = copy.deepcopy(instance)
def consistency_setter(consistency):
instance.__consistency__ = consistency
return instance
@@ -229,7 +236,7 @@ class ColumnDescriptor(object):
"""
try:
return instance._values[self.column.column_name].getval()
except AttributeError as e:
except AttributeError:
return self.query_evaluator
def __set__(self, instance, value):
@@ -258,9 +265,11 @@ class BaseModel(object):
The base model class, don't inherit from this, inherit from Model, defined below
"""
class DoesNotExist(_DoesNotExist): pass
class DoesNotExist(_DoesNotExist):
pass
class MultipleObjectsReturned(_MultipleObjectsReturned): pass
class MultipleObjectsReturned(_MultipleObjectsReturned):
pass
objects = QuerySetDescriptor()
ttl = TTLDescriptor()
@@ -269,7 +278,7 @@ class BaseModel(object):
# custom timestamps, see USING TIMESTAMP X
timestamp = TimestampDescriptor()
if_not_exists = IfNotExistsDescriptor()
# _len is lazily created by __len__
@@ -302,8 +311,7 @@ class BaseModel(object):
__queryset__ = ModelQuerySet
__dmlquery__ = DMLQuery
__consistency__ = None # can be set per query
# Additional table properties
__bloom_filter_fp_chance__ = None
@@ -318,9 +326,9 @@ class BaseModel(object):
__read_repair_chance__ = None
__replicate_on_write__ = None
_timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP)
_if_not_exists = False # optional if_not_exists flag to check existence before insertion
def __init__(self, **values):
self._values = {}
@@ -343,21 +351,19 @@ class BaseModel(object):
self._batch = None
self._timeout = NOT_SET
def __repr__(self):
"""
Pretty printing of models by their primary key
"""
return '{} <{}>'.format(self.__class__.__name__,
', '.join(('{}={}'.format(k, getattr(self, k)) for k,v in six.iteritems(self._primary_keys)))
', '.join(('{}={}'.format(k, getattr(self, k)) for k, v in six.iteritems(self._primary_keys)))
)
@classmethod
def _discover_polymorphic_submodels(cls):
if not cls._is_polymorphic_base:
raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes')
def _discover(klass):
if not klass._is_polymorphic_base and klass.__polymorphic_key__ is not None:
cls._polymorphic_map[klass.__polymorphic_key__] = klass
@@ -381,7 +387,7 @@ class BaseModel(object):
# and translate that into our local fields
# the db_map is a db_field -> model field map
items = values.items()
field_dict = dict([(cls._db_map.get(k, k),v) for k,v in items])
field_dict = dict([(cls._db_map.get(k, k), v) for k, v in items])
if cls._is_polymorphic:
poly_key = field_dict.get(cls._polymorphic_column_name)
@@ -421,8 +427,9 @@ class BaseModel(object):
:return:
"""
if not self._is_persisted: return False
pks = self._primary_keys.keys()
if not self._is_persisted:
return False
return all([not self._values[k].changed for k in self._primary_keys])
@classmethod
@@ -481,11 +488,14 @@ class BaseModel(object):
ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)
cf_name += ccase(cls.__name__)
#trim to less than 48 characters or cassandra will complain
# trim to less than 48 characters or cassandra will complain
cf_name = cf_name[-48:]
cf_name = cf_name.lower()
cf_name = re.sub(r'^_+', '', cf_name)
if not include_keyspace: return cf_name
if not include_keyspace:
return cf_name
return '{}.{}'.format(cls._get_keyspace(), cf_name)
def validate(self):
@@ -499,8 +509,7 @@ class BaseModel(object):
val = col.validate(v)
setattr(self, name, val)
### Let an instance be used like a dict of its columns keys/values
# Let an instance be used like a dict of its columns keys/values
def __iter__(self):
""" Iterate over column ids. """
for column_id in self._columns.keys():
@@ -620,7 +629,6 @@ class BaseModel(object):
else:
setattr(self, self._polymorphic_column_name, self.__polymorphic_key__)
is_new = self.pk is None
self.validate()
self.__dmlquery__(self.__class__, self,
batch=self._batch,
@@ -631,7 +639,7 @@ class BaseModel(object):
transaction=self._transaction,
timeout=self._timeout).save()
#reset the value managers
# reset the value managers
for v in self._values.values():
v.reset_previous_value()
self._is_persisted = True
@@ -680,7 +688,7 @@ class BaseModel(object):
transaction=self._transaction,
timeout=self._timeout).update()
#reset the value managers
# reset the value managers
for v in self._values.values():
v.reset_previous_value()
self._is_persisted = True
@@ -704,7 +712,7 @@ class BaseModel(object):
"""
Returns a list of the columns that have been updated since instantiation or save
"""
return [k for k,v in self._values.items() if v.changed]
return [k for k, v in self._values.items() if v.changed]
@classmethod
def _class_batch(cls, batch):
@@ -715,29 +723,28 @@ class BaseModel(object):
self._batch = batch
return self
batch = hybrid_classmethod(_class_batch, _inst_batch)
class ModelMetaClass(type):
def __new__(cls, name, bases, attrs):
#move column definitions into columns dict
#and set default column names
# move column definitions into columns dict
# and set default column names
column_dict = OrderedDict()
primary_keys = OrderedDict()
pk_name = None
#get inherited properties
# get inherited properties
inherited_columns = OrderedDict()
for base in bases:
for k,v in getattr(base, '_defined_columns', {}).items():
inherited_columns.setdefault(k,v)
for k, v in getattr(base, '_defined_columns', {}).items():
inherited_columns.setdefault(k, v)
#short circuit __abstract__ inheritance
# short circuit __abstract__ inheritance
is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False)
#short circuit __polymorphic_key__ inheritance
# short circuit __polymorphic_key__ inheritance
attrs['__polymorphic_key__'] = attrs.get('__polymorphic_key__', None)
def _transform_column(col_name, col_obj):
@@ -745,11 +752,10 @@ class ModelMetaClass(type):
if col_obj.primary_key:
primary_keys[col_name] = col_obj
col_obj.set_column_name(col_name)
#set properties
# set properties
attrs[col_name] = ColumnDescriptor(col_obj)
column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.Column)]
#column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position))
column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)]
column_definitions = sorted(column_definitions, key=lambda x: x[1].position)
is_polymorphic_base = any([c[1].polymorphic_key for c in column_definitions])
@@ -780,7 +786,7 @@ class ModelMetaClass(type):
defined_columns = OrderedDict(column_definitions)
# check for primary key
if not is_abstract and not any([v.primary_key for k,v in column_definitions]):
if not is_abstract and not any([v.primary_key for k, v in column_definitions]):
raise ModelDefinitionException("At least 1 primary key is required.")
counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)]
@@ -790,7 +796,7 @@ class ModelMetaClass(type):
has_partition_keys = any(v.partition_key for (k, v) in column_definitions)
#transform column definitions
# transform column definitions
for k, v in column_definitions:
# don't allow a column with the same name as a built-in attribute or method
if k in BaseModel.__dict__:
@@ -810,7 +816,7 @@ class ModelMetaClass(type):
partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key)
clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key)
#setup partition key shortcut
# setup partition key shortcut
if len(partition_keys) == 0:
if not is_abstract:
raise ModelException("at least one partition key must be defined")
@@ -835,12 +841,12 @@ class ModelMetaClass(type):
raise ModelException("invalid clustering order {} for column {}".format(repr(v.clustering_order), v.db_field_name))
col_names.add(v.db_field_name)
#create db_name -> model name map for loading
# create db_name -> model name map for loading
db_map = {}
for field_name, col in column_dict.items():
db_map[col.db_field_name] = field_name
#add management members to the class
# add management members to the class
attrs['_columns'] = column_dict
attrs['_primary_keys'] = primary_keys
attrs['_defined_columns'] = defined_columns
@@ -862,29 +868,31 @@ class ModelMetaClass(type):
attrs['_polymorphic_column_name'] = polymorphic_column_name
attrs['_polymorphic_map'] = {} if is_polymorphic_base else None
#setup class exceptions
# setup class exceptions
DoesNotExistBase = None
for base in bases:
DoesNotExistBase = getattr(base, 'DoesNotExist', None)
if DoesNotExistBase is not None: break
if DoesNotExistBase is not None:
break
DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist)
attrs['DoesNotExist'] = type('DoesNotExist', (DoesNotExistBase,), {})
MultipleObjectsReturnedBase = None
for base in bases:
MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None)
if MultipleObjectsReturnedBase is not None: break
if MultipleObjectsReturnedBase is not None:
break
MultipleObjectsReturnedBase = DoesNotExistBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned)
attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {})
#create the class and add a QuerySet to it
# create the class and add a QuerySet to it
klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs)
return klass
import six
@six.add_metaclass(ModelMetaClass)
class Model(BaseModel):
__abstract__ = True
@@ -906,7 +914,7 @@ class Model(BaseModel):
__default_ttl__ = None
"""
*Optional* The default ttl used by this model.
This can be overridden by using the :meth:`~.ttl` method.
"""

View File

@@ -3,6 +3,7 @@ from cassandra.cqlengine.query import AbstractQueryableColumn, SimpleQuerySet
from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
class QuerySetDescriptor(object):
"""
returns a fresh queryset for the given model
@@ -63,8 +64,11 @@ class NamedTable(object):
objects = QuerySetDescriptor()
class DoesNotExist(_DoesNotExist): pass
class MultipleObjectsReturned(_MultipleObjectsReturned): pass
class DoesNotExist(_DoesNotExist):
pass
class MultipleObjectsReturned(_MultipleObjectsReturned):
pass
def __init__(self, keyspace, name):
self.keyspace = keyspace
@@ -118,4 +122,3 @@ class NamedKeyspace(object):
name that belongs to this keyspace
"""
return NamedTable(self.name, name)
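
A minimal usage sketch for the named helpers above; it assumes the keyspace exposes tables through the table() accessor implied by the final hunk, and the filter field is illustrative:

    from cassandra.cqlengine.named import NamedKeyspace

    ks = NamedKeyspace('example')
    users = ks.table('users')  # equivalent to NamedTable('example', 'users')
    rows = users.objects.filter(user_id=42)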

View File

@@ -2,7 +2,9 @@ import six
import sys
class QueryOperatorException(Exception): pass
class QueryOperatorException(Exception):
pass
# move to central spot
class UnicodeMixin(object):
@@ -31,12 +33,14 @@ class BaseQueryOperator(UnicodeMixin):
raise QueryOperatorException("get_operator can only be called from a BaseQueryOperator subclass")
if not hasattr(cls, 'opmap'):
cls.opmap = {}
def _recurse(klass):
if klass.symbol:
cls.opmap[klass.symbol.upper()] = klass
for subklass in klass.__subclasses__():
_recurse(subklass)
pass
_recurse(cls)
try:
return cls.opmap[symbol.upper()]

View File

@@ -1,6 +1,7 @@
import copy
from datetime import datetime, timedelta
import time
import six
from cassandra.cqlengine import columns
from cassandra.cqlengine.connection import execute, NOT_SET
@@ -17,11 +18,16 @@ from cassandra.cqlengine.statements import (WhereClause, SelectStatement, Delete
TransactionClause)
class QueryException(CQLEngineException): pass
class DoesNotExist(QueryException): pass
class MultipleObjectsReturned(QueryException): pass
class QueryException(CQLEngineException):
pass
import six
class DoesNotExist(QueryException):
pass
class MultipleObjectsReturned(QueryException):
pass
def check_applied(result):
@@ -30,7 +36,7 @@ def check_applied(result):
if that value is false, it means our light-weight transaction was not
applied to the database.
"""
if result and '[applied]' in result[0] and result[0]['[applied]'] == False:
if result and '[applied]' in result[0] and not result[0]['[applied]']:
raise LWTException('')
@@ -77,8 +83,8 @@ class AbstractQueryableColumn(UnicodeMixin):
class BatchType(object):
Unlogged = 'UNLOGGED'
Counter = 'COUNTER'
class BatchQuery(object):
@@ -194,8 +200,9 @@ class BatchQuery(object):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
#don't execute if there was an exception by default
if exc_type is not None and not self._execute_on_exception: return
# don't execute if there was an exception by default
if exc_type is not None and not self._execute_on_exception:
return
self.execute()
@@ -205,29 +212,29 @@ class AbstractQuerySet(object):
super(AbstractQuerySet, self).__init__()
self.model = model
#Where clause filters
# Where clause filters
self._where = []
# Transaction clause filters
self._transaction = []
#ordering arguments
# ordering arguments
self._order = []
self._allow_filtering = False
#CQL has a default limit of 10000, it's defined here
#because explicit is better than implicit
# CQL has a default limit of 10000, it's defined here
# because explicit is better than implicit
self._limit = 10000
#see the defer and only methods
# see the defer and only methods
self._defer_fields = []
self._only_fields = []
self._values_list = False
self._flat_values_list = False
#results cache
# results cache
self._con = None
self._cur = None
self._result_cache = None
@@ -265,7 +272,7 @@ class AbstractQuerySet(object):
def __deepcopy__(self, memo):
clone = self.__class__(self.model)
for k, v in self.__dict__.items():
if k in ['_con', '_cur', '_result_cache', '_result_idx']: # don't clone these
clone.__dict__[k] = None
elif k == '_batch':
# we need to keep the same batch instance across
@@ -284,7 +291,7 @@ class AbstractQuerySet(object):
self._execute_query()
return len(self._result_cache)
#----query generation / execution----
# ----query generation / execution----
def _select_fields(self):
""" returns the fields to select """
@@ -308,7 +315,7 @@ class AbstractQuerySet(object):
allow_filtering=self._allow_filtering
)
#----Reads------
# ----Reads------
def _execute_query(self):
if self._batch:
@@ -330,7 +337,7 @@ class AbstractQuerySet(object):
self._result_idx += 1
self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx])
#return the connection to the connection pool if we have all objects
# return the connection to the connection pool if we have all objects
if self._result_cache and self._result_idx == (len(self._result_cache) - 1):
self._con = None
self._cur = None
@@ -350,7 +357,7 @@ class AbstractQuerySet(object):
num_results = len(self._result_cache)
if isinstance(s, slice):
#calculate the amount of results that need to be loaded
# calculate the amount of results that need to be loaded
end = num_results if s.step is None else s.step
if end < 0:
end += num_results
@@ -359,11 +366,12 @@ class AbstractQuerySet(object):
self._fill_result_cache_to_idx(end)
return self._result_cache[s.start:s.stop:s.step]
else:
#return the object at this index
# return the object at this index
s = int(s)
#handle negative indexing
if s < 0: s += num_results
# handle negative indexing
if s < 0:
s += num_results
if s >= num_results:
raise IndexError
@@ -453,7 +461,6 @@ class AbstractQuerySet(object):
if not isinstance(val, Token):
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
column = columns._PartitionKeysToken(self.model)
quote_field = False
else:
raise QueryException("Can't resolve column name: '{}'".format(col_name))
@@ -484,7 +491,7 @@ class AbstractQuerySet(object):
Returns a QuerySet filtered on the keyword arguments
"""
#add arguments to the where clause filters
# add arguments to the where clause filters
if len([x for x in kwargs.values() if x is None]):
raise CQLEngineException("None values on filter are not allowed")
@@ -497,7 +504,7 @@ class AbstractQuerySet(object):
for arg, val in kwargs.items():
col_name, col_op = self._parse_filter_arg(arg)
quote_field = True
#resolve column and operator
# resolve column and operator
try:
column = self.model._get_column(col_name)
except KeyError:
@@ -519,7 +526,7 @@ class AbstractQuerySet(object):
len(val.value), len(partition_columns)))
val.set_columns(partition_columns)
#get query operator, or use equals if not supplied
# get query operator, or use equals if not supplied
operator_class = BaseWhereOperator.get_operator(col_op or 'EQ')
operator = operator_class()
@@ -559,8 +566,7 @@ class AbstractQuerySet(object):
if len(self._result_cache) == 0:
raise self.model.DoesNotExist
elif len(self._result_cache) > 1:
raise self.model.MultipleObjectsReturned(
'{} objects found'.format(len(self._result_cache)))
raise self.model.MultipleObjectsReturned('{} objects found'.format(len(self._result_cache)))
else:
return self[0]
@@ -665,7 +671,7 @@ class AbstractQuerySet(object):
if clone._defer_fields or clone._only_fields:
raise QueryException("QuerySet alread has only or defer fields defined")
#check for strange fields
# check for strange fields
missing_fields = [f for f in fields if f not in self.model._columns.keys()]
if missing_fields:
raise QueryException(
@@ -698,7 +704,7 @@ class AbstractQuerySet(object):
"""
Deletes the contents of a query
"""
#validate where clause
# validate where clause
partition_key = [x for x in self.model._primary_keys.values()][0]
if not any([c.field == partition_key.column_name for c in self._where]):
raise QueryException("The partition key must be defined on delete queries")
@@ -759,14 +765,14 @@ class ModelQuerySet(AbstractQuerySet):
"""
def _validate_select_where(self):
""" Checks that a filterset will not create invalid select statement """
#check that there's either a = or IN relationship with a primary key or indexed field
# check that there's either a = or IN relationship with a primary key or indexed field
equal_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)]
token_comparison = any([w for w in self._where if isinstance(w.value, Token)])
if not any([w.primary_key or w.index for w in equal_ops]) and not token_comparison and not self._allow_filtering:
raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field')
if not self._allow_filtering:
#if the query is not on an indexed field
# if the query is not on an indexed field
if not any([w.index for w in equal_ops]):
if not any([w.partition_key for w in equal_ops]) and not token_comparison:
raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the queryset')
@@ -783,9 +789,9 @@ class ModelQuerySet(AbstractQuerySet):
def _get_result_constructor(self):
""" Returns a function that will be used to instantiate query results """
if not self._values_list: # we want models
return lambda rows: self.model._construct_instance(rows)
elif self._flat_values_list: # the user has requested flattened list (1 value per row)
return lambda row: row.popitem()[1]
else:
return lambda row: self._get_row_value_list(self._only_fields, row)
@@ -803,7 +809,7 @@ class ModelQuerySet(AbstractQuerySet):
if column is None:
raise QueryException("Can't resolve the column name: '{}'".format(colname))
#validate the column selection
# validate the column selection
if not column.primary_key:
raise QueryException(
"Can't order on '{}', can only order on (clustered) primary keys".format(colname))
@@ -958,7 +964,7 @@ class ModelQuerySet(AbstractQuerySet):
# we should not provide default values in this use case.
val = col.validate(val)
if val is None:
nulled_columns.add(col_name)
continue
@@ -1072,11 +1078,11 @@ class DMLQuery(object):
timestamp=self._timestamp, transactions=self._transaction)
for name, col in self.instance._clustering_keys.items():
null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None))
#get defined fields and their column names
# get defined fields and their column names
for name, col in self.model._columns.items():
# if clustering key is null, don't include non static columns
if null_clustering_key and not col.static and not col.partition_key:
continue
if not col.is_primary_key:
val = getattr(self.instance, name, None)
val_mgr = self.instance._values[name]
@@ -1088,19 +1094,24 @@ class DMLQuery(object):
# don't update something if it hasn't changed
if not val_mgr.changed and not isinstance(col, columns.Counter):
continue
static_changed_only = static_changed_only and col.static
if isinstance(col, (columns.BaseContainerColumn, columns.Counter)):
# get appropriate clause
if isinstance(col, columns.List): klass = ListUpdateClause
elif isinstance(col, columns.Map): klass = MapUpdateClause
elif isinstance(col, columns.Set): klass = SetUpdateClause
elif isinstance(col, columns.Counter): klass = CounterUpdateClause
else: raise RuntimeError
if isinstance(col, columns.List):
klass = ListUpdateClause
elif isinstance(col, columns.Map):
klass = MapUpdateClause
elif isinstance(col, columns.Set):
klass = SetUpdateClause
elif isinstance(col, columns.Counter):
klass = CounterUpdateClause
else:
raise RuntimeError
# do the stuff
clause = klass(col.db_field_name, val,
previous=val_mgr.previous_value, column=col)
if clause.get_context_size() > 0:
statement.add_assignment_clause(clause)
else:
@@ -1145,7 +1156,7 @@ class DMLQuery(object):
static_save_only = static_save_only and col._val_is_null(getattr(self.instance, name, None))
for name, col in self.instance._columns.items():
if static_save_only and not col.static and not col.partition_key:
continue
val = getattr(self.instance, name, None)
if col._val_is_null(val):
if self.instance._values[name].changed:
@@ -1171,7 +1182,9 @@ class DMLQuery(object):
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp)
for name, col in self.model._primary_keys.items():
if (not col.partition_key) and (getattr(self.instance, name) is None): continue
if (not col.partition_key) and (getattr(self.instance, name) is None):
continue
ds.add_where_clause(WhereClause(
col.db_field_name,
EqualsOperator(),
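
One behavioral detail from the BatchQuery hunk above: as a context manager, the batch is discarded when the block raises unless it was constructed to execute on exceptions. A sketch; the keyword name is assumed from the _execute_on_exception attribute, and User is the illustrative model from the management sketch:

    from uuid import uuid4
    from cassandra.cqlengine.query import BatchQuery

    # default: an exception inside the block skips execute()
    with BatchQuery() as b:
        User.batch(b).create(user_id=uuid4(), name='a')

    # opt in to executing even when the block raises
    with BatchQuery(execute_on_exception=True) as b:
        User.batch(b).create(user_id=uuid4(), name='b')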

View File

@@ -7,7 +7,8 @@ from cassandra.cqlengine.functions import QueryValue
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator
class StatementException(Exception): pass
class StatementException(Exception):
pass
class UnicodeMixin(object):
@@ -16,6 +17,7 @@ class UnicodeMixin(object):
else:
__str__ = lambda x: six.text_type(x).encode('utf-8')
class ValueQuoter(UnicodeMixin):
def __init__(self, value):
@@ -28,7 +30,7 @@ class ValueQuoter(UnicodeMixin):
elif isinstance(self.value, (list, tuple)):
return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']'
elif isinstance(self.value, dict):
return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k,v in self.value.items()]) + '}'
return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}'
elif isinstance(self.value, set):
return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}'
return cql_quote(self.value)
@@ -215,7 +217,8 @@ class SetUpdateClause(ContainerUpdateClause):
self._analyzed = True
def get_context_size(self):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
if (self.previous is None and
not self._assignments and
self._additions is None and
@@ -224,7 +227,8 @@ class SetUpdateClause(ContainerUpdateClause):
return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals))
def update_context(self, ctx):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
ctx_id = self.context_id
if (self.previous is None and
self._assignments is None and
@@ -250,7 +254,8 @@ class ListUpdateClause(ContainerUpdateClause):
self._prepend = None
def __unicode__(self):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
qs = []
ctx_id = self.context_id
if self._assignments is not None:
@@ -267,11 +272,13 @@ class ListUpdateClause(ContainerUpdateClause):
return ', '.join(qs)
def get_context_size(self):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend))
def update_context(self, ctx):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
ctx_id = self.context_id
if self._assignments is not None:
ctx[str(ctx_id)] = self._to_database(self._assignments)
@@ -313,13 +320,13 @@ class ListUpdateClause(ContainerUpdateClause):
else:
# the max start idx we want to compare
search_space = len(self.value) - max(0, len(self.previous)-1)
search_space = len(self.value) - max(0, len(self.previous) - 1)
# the size of the sub lists we want to look at
search_size = len(self.previous)
for i in range(search_space):
#slice boundary
# slice boundary
j = i + search_size
sub = self.value[i:j]
idx_cmp = lambda idx: self.previous[idx] == sub[idx]
@@ -354,13 +361,15 @@ class MapUpdateClause(ContainerUpdateClause):
self._analyzed = True
def get_context_size(self):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
if self.previous is None and not self._updates:
return 1
return len(self._updates or []) * 2
def update_context(self, ctx):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
ctx_id = self.context_id
if self.previous is None and not self._updates:
ctx[str(ctx_id)] = {}
@@ -372,7 +381,8 @@ class MapUpdateClause(ContainerUpdateClause):
ctx_id += 2
def __unicode__(self):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
qs = []
ctx_id = self.context_id
@@ -439,16 +449,19 @@ class MapDeleteClause(BaseDeleteClause):
self._analyzed = True
def update_context(self, ctx):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
for idx, key in enumerate(self._removals):
ctx[str(self.context_id + idx)] = key
def get_context_size(self):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
return len(self._removals)
def __unicode__(self):
if not self._analyzed: self._analyze()
if not self._analyzed:
self._analyze()
return ', '.join(['"{}"[%({})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))])
@@ -521,7 +534,6 @@ class BaseCQLStatement(UnicodeMixin):
def __unicode__(self):
raise NotImplementedError
def __repr__(self):
return self.__unicode__()
@@ -645,13 +657,12 @@ class InsertStatement(AssignmentStatement):
ttl=None,
timestamp=None,
if_not_exists=False):
super(InsertStatement, self).__init__(
table,
assignments=assignments,
consistency=consistency,
where=where,
ttl=ttl,
timestamp=timestamp)
super(InsertStatement, self).__init__(table,
assignments=assignments,
consistency=consistency,
where=where,
ttl=ttl,
timestamp=timestamp)
self.if_not_exists = if_not_exists
@@ -813,4 +824,3 @@ class DeleteStatement(BaseCQLStatement):
qs += [self._where]
return ' '.join(qs)
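
Several clauses above repeat the guard reformatted throughout this file: analyze once, lazily, before sizing or rendering. The shape in miniature:

    class LazyClause(object):
        def __init__(self, value):
            self.value = value
            self._analyzed = False

        def _analyze(self):
            # the real clauses diff previous/new container values here
            self._size = len(self.value)
            self._analyzed = True

        def get_context_size(self):
            if not self._analyzed:
                self._analyze()
            return self._size

    assert LazyClause([1, 2, 3]).get_context_size() == 3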