From 62c25d60e9937bc1c9faf2d1e6445d029543db5d Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Tue, 10 Feb 2015 11:47:22 -0600 Subject: [PATCH] PEP-8 cleanup --- cassandra/cqlengine/columns.py | 89 ++++++++++++++---------- cassandra/cqlengine/connection.py | 14 ++-- cassandra/cqlengine/exceptions.py | 28 ++++++-- cassandra/cqlengine/functions.py | 5 +- cassandra/cqlengine/management.py | 53 +++++++------- cassandra/cqlengine/models.py | 112 ++++++++++++++++-------------- cassandra/cqlengine/named.py | 9 ++- cassandra/cqlengine/operators.py | 6 +- cassandra/cqlengine/query.py | 109 ++++++++++++++++------------- cassandra/cqlengine/statements.py | 58 +++++++++------- 10 files changed, 282 insertions(+), 201 deletions(-) diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py index d27a9386..2181413b 100644 --- a/cassandra/cqlengine/columns.py +++ b/cassandra/cqlengine/columns.py @@ -1,4 +1,3 @@ -#column field types from copy import deepcopy, copy from datetime import datetime from datetime import date @@ -86,7 +85,7 @@ class ValueQuoter(object): class Column(object): - #the cassandra type this column maps to + # the cassandra type this column maps to db_type = None value_manager = BaseValueManager @@ -161,13 +160,13 @@ class Column(object): self.required = required self.clustering_order = clustering_order self.polymorphic_key = polymorphic_key - #the column name in the model definition + # the column name in the model definition self.column_name = None self.static = static self.value = None - #keep track of instantiation order + # keep track of instantiation order self.position = Column.instance_counter Column.instance_counter += 1 @@ -266,17 +265,18 @@ class Blob(Column): return bytearray(val) def to_python(self, value): - #return value[2:].decode('hex') return value Bytes = Blob + class Ascii(Column): """ Stores a US-ASCII character string """ db_type = 'ascii' + class Inet(Column): """ Stores an IP address in IPv4 or IPv6 format @@ -309,7 +309,8 @@ class Text(Column): def validate(self, value): value = super(Text, self).validate(value) - if value is None: return + if value is None: + return if not isinstance(value, (six.string_types, bytearray)) and value is not None: raise ValidationError('{} {} is not a string'.format(self.column_name, type(value))) if self.max_length: @@ -330,7 +331,8 @@ class Integer(Column): def validate(self, value): val = super(Integer, self).validate(value) - if val is None: return + if val is None: + return try: return int(val) except (TypeError, ValueError): @@ -409,7 +411,8 @@ class DateTime(Column): db_type = 'timestamp' def to_python(self, value): - if value is None: return + if value is None: + return if isinstance(value, datetime): return value elif isinstance(value, date): @@ -421,7 +424,8 @@ class DateTime(Column): def to_database(self, value): value = super(DateTime, self).to_database(value) - if value is None: return + if value is None: + return if not isinstance(value, datetime): if isinstance(value, date): value = datetime(value.year, value.month, value.day) @@ -443,7 +447,8 @@ class Date(Column): db_type = 'timestamp' def to_python(self, value): - if value is None: return + if value is None: + return if isinstance(value, datetime): return value.date() elif isinstance(value, date): @@ -455,7 +460,8 @@ class Date(Column): def to_database(self, value): value = super(Date, self).to_database(value) - if value is None: return + if value is None: + return if isinstance(value, datetime): value = value.date() if not isinstance(value, date): @@ -474,9 +480,11 @@ class UUID(Column): def validate(self, value): val = super(UUID, self).validate(value) - if val is None: return + if val is None: + return from uuid import UUID as _UUID - if isinstance(val, _UUID): return val + if isinstance(val, _UUID): + return val if isinstance(val, six.string_types) and self.re_uuid.match(val): return _UUID(val) raise ValidationError("{} {} is not a valid uuid".format(self.column_name, value)) @@ -510,7 +518,7 @@ class TimeUUID(UUID): epoch = datetime(1970, 1, 1, tzinfo=dt.tzinfo) offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 - timestamp = (dt - epoch).total_seconds() - offset + timestamp = (dt - epoch).total_seconds() - offset node = None clock_seq = None @@ -529,7 +537,7 @@ class TimeUUID(UUID): if node is None: node = getnode() return pyUUID(fields=(time_low, time_mid, time_hi_version, - clock_seq_hi_variant, clock_seq_low, node), version=1) + clock_seq_hi_variant, clock_seq_low, node), version=1) class Boolean(Column): @@ -563,7 +571,8 @@ class Float(Column): def validate(self, value): value = super(Float, self).validate(value) - if value is None: return + if value is None: + return try: return float(value) except (TypeError, ValueError): @@ -586,7 +595,8 @@ class Decimal(Column): from decimal import Decimal as _Decimal from decimal import InvalidOperation val = super(Decimal, self).validate(value) - if val is None: return + if val is None: + return try: return _Decimal(val) except InvalidOperation: @@ -679,7 +689,8 @@ class Set(BaseContainerColumn): def validate(self, value): val = super(Set, self).validate(value) - if val is None: return + if val is None: + return types = (set,) if self.strict else (set, list, tuple) if not isinstance(val, types): if self.strict: @@ -693,18 +704,19 @@ class Set(BaseContainerColumn): return {self.value_col.validate(v) for v in val} def to_python(self, value): - if value is None: return set() + if value is None: + return set() return {self.value_col.to_python(v) for v in value} def to_database(self, value): - if value is None: return None + if value is None: + return None - if isinstance(value, self.Quoter): return value + if isinstance(value, self.Quoter): + return value return self.Quoter({self.value_col.to_database(v) for v in value}) - - class List(BaseContainerColumn): """ Stores a list of ordered values @@ -727,7 +739,8 @@ class List(BaseContainerColumn): def validate(self, value): val = super(List, self).validate(value) - if val is None: return + if val is None: + return if not isinstance(val, (set, list, tuple)): raise ValidationError('{} {} is not a list object'.format(self.column_name, val)) if None in val: @@ -735,12 +748,15 @@ class List(BaseContainerColumn): return [self.value_col.validate(v) for v in val] def to_python(self, value): - if value is None: return [] + if value is None: + return [] return [self.value_col.to_python(v) for v in value] def to_database(self, value): - if value is None: return None - if isinstance(value, self.Quoter): return value + if value is None: + return None + if isinstance(value, self.Quoter): + return value return self.Quoter([self.value_col.to_database(v) for v in value]) @@ -757,7 +773,7 @@ class Map(BaseContainerColumn): def __str__(self): cq = cql_quote - return '{' + ', '.join([cq(k) + ':' + cq(v) for k,v in self.value.items()]) + '}' + return '{' + ', '.join([cq(k) + ':' + cq(v) for k, v in self.value.items()]) + '}' def get(self, key): return self.value.get(key) @@ -802,22 +818,24 @@ class Map(BaseContainerColumn): def validate(self, value): val = super(Map, self).validate(value) - if val is None: return + if val is None: + return if not isinstance(val, dict): raise ValidationError('{} {} is not a dict object'.format(self.column_name, val)) - return {self.key_col.validate(k):self.value_col.validate(v) for k,v in val.items()} + return {self.key_col.validate(k): self.value_col.validate(v) for k, v in val.items()} def to_python(self, value): if value is None: return {} if value is not None: - return {self.key_col.to_python(k): self.value_col.to_python(v) for k,v in value.items()} + return {self.key_col.to_python(k): self.value_col.to_python(v) for k, v in value.items()} def to_database(self, value): - if value is None: return None - if isinstance(value, self.Quoter): return value - return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()}) - + if value is None: + return None + if isinstance(value, self.Quoter): + return value + return self.Quoter({self.key_col.to_database(k): self.value_col.to_database(v) for k, v in value.items()}) class _PartitionKeysToken(Column): @@ -842,4 +860,3 @@ class _PartitionKeysToken(Column): def get_cql(self): return "token({})".format(", ".join(c.cql for c in self.partition_columns)) - diff --git a/cassandra/cqlengine/connection.py b/cassandra/cqlengine/connection.py index c87fc495..d4b1dc96 100644 --- a/cassandra/cqlengine/connection.py +++ b/cassandra/cqlengine/connection.py @@ -1,6 +1,6 @@ from collections import namedtuple -import six import logging +import six from cassandra import ConsistencyLevel from cassandra.cluster import Cluster, _NOT_SET, NoHostAvailable @@ -9,11 +9,13 @@ from cassandra.cqlengine.exceptions import CQLEngineException, UndefinedKeyspace from cassandra.cqlengine.statements import BaseCQLStatement -LOG = logging.getLogger('cqlengine.cql') +log = logging.getLogger(__name__) + NOT_SET = _NOT_SET # required for passing timeout to Session.execute -class CQLConnectionError(CQLEngineException): pass +class CQLConnectionError(CQLEngineException): + pass Host = namedtuple('Host', ['name', 'port']) @@ -22,6 +24,7 @@ session = None lazy_connect_args = None default_consistency_level = None + def setup( hosts, default_keyspace, @@ -94,21 +97,24 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET): elif isinstance(query, six.string_types): query = SimpleStatement(query, consistency_level=consistency_level) - LOG.info(query.query_string) + log.debug(query.query_string) params = params or {} result = session.execute(query, params, timeout=timeout) return result + def get_session(): handle_lazy_connect() return session + def get_cluster(): handle_lazy_connect() return cluster + def handle_lazy_connect(): global lazy_connect_args if lazy_connect_args: diff --git a/cassandra/cqlengine/exceptions.py b/cassandra/cqlengine/exceptions.py index 19b5c1fe..8f60c90a 100644 --- a/cassandra/cqlengine/exceptions.py +++ b/cassandra/cqlengine/exceptions.py @@ -1,8 +1,22 @@ -#cqlengine exceptions -class CQLEngineException(Exception): pass -class ModelException(CQLEngineException): pass -class ValidationError(CQLEngineException): pass +class CQLEngineException(Exception): + pass -class UndefinedKeyspaceException(CQLEngineException): pass -class LWTException(CQLEngineException): pass -class IfNotExistsWithCounterColumn(CQLEngineException): pass + +class ModelException(CQLEngineException): + pass + + +class ValidationError(CQLEngineException): + pass + + +class UndefinedKeyspaceException(CQLEngineException): + pass + + +class LWTException(CQLEngineException): + pass + + +class IfNotExistsWithCounterColumn(CQLEngineException): + pass diff --git a/cassandra/cqlengine/functions.py b/cassandra/cqlengine/functions.py index 78d42479..c7703239 100644 --- a/cassandra/cqlengine/functions.py +++ b/cassandra/cqlengine/functions.py @@ -5,12 +5,14 @@ import sys from cassandra.cqlengine.exceptions import ValidationError # move to central spot + class UnicodeMixin(object): if sys.version_info > (3, 0): __str__ = lambda x: x.__unicode__() else: __str__ = lambda x: six.text_type(x).encode('utf-8') + class QueryValue(UnicodeMixin): """ Base class for query filter values. Subclasses of these classes can @@ -42,6 +44,8 @@ class BaseQueryFunction(QueryValue): be passed into .filter() and will be translated into CQL functions in the resulting query """ + pass + class MinTimeUUID(BaseQueryFunction): """ @@ -123,4 +127,3 @@ class Token(BaseQueryFunction): def update_context(self, ctx): for i, (col, val) in enumerate(zip(self._columns, self.value)): ctx[str(self.context_id + i)] = col.to_database(val) - diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py index f9c963a0..d5b3f817 100644 --- a/cassandra/cqlengine/management.py +++ b/cassandra/cqlengine/management.py @@ -12,11 +12,12 @@ from cassandra.cqlengine.named import NamedTable Field = namedtuple('Field', ['name', 'type']) -logger = logging.getLogger(__name__) +log = logging.getLogger(__name__) # system keyspaces schema_columnfamilies = NamedTable('system', 'schema_columnfamilies') + def create_keyspace(name, strategy_class, replication_factor, durable_writes=True, **replication_values): """ *Deprecated - this will likely be repaced with something specialized per replication strategy.* @@ -38,10 +39,10 @@ def create_keyspace(name, strategy_class, replication_factor, durable_writes=Tru cluster = get_cluster() if name not in cluster.metadata.keyspaces: - #try the 1.2 method + # try the 1.2 method replication_map = { 'class': strategy_class, - 'replication_factor':replication_factor + 'replication_factor': replication_factor } replication_map.update(replication_values) if strategy_class.lower() != 'simplestrategy': @@ -76,6 +77,7 @@ def delete_keyspace(name): if name in cluster.metadata.keyspaces: execute("DROP KEYSPACE {}".format(name)) + def sync_table(model): """ Inspects the model and creates / updates the corresponding table and columns. @@ -95,8 +97,7 @@ def sync_table(model): if model.__abstract__: raise CQLEngineException("cannot create table from abstract model") - - #construct query string + # construct query string cf_name = model.column_family_name() raw_cf_name = model.column_family_name(include_keyspace=False) @@ -107,7 +108,7 @@ def sync_table(model): keyspace = cluster.metadata.keyspaces[ks_name] tables = keyspace.tables - #check for an existing column family + # check for an existing column family if raw_cf_name not in tables: qs = get_create_table(model) @@ -123,20 +124,21 @@ def sync_table(model): fields = get_fields(model) field_names = [x.name for x in fields] for name, col in model._columns.items(): - if col.primary_key or col.partition_key: continue # we can't mess with the PK - if col.db_field_name in field_names: continue # skip columns already defined + if col.primary_key or col.partition_key: + continue # we can't mess with the PK + if col.db_field_name in field_names: + continue # skip columns already defined # add missing column using the column def query = "ALTER TABLE {} add {}".format(cf_name, col.get_column_def()) - logger.debug(query) + log.debug(query) execute(query) update_compaction(model) - table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name] - indexes = [c for n,c in model._columns.items() if c.index] + indexes = [c for n, c in model._columns.items() if c.index] for column in indexes: if table.columns[column.db_field_name].index: @@ -148,20 +150,23 @@ def sync_table(model): qs = ' '.join(qs) execute(qs) + def get_create_table(model): cf_name = model.column_family_name() qs = ['CREATE TABLE {}'.format(cf_name)] - #add column types - pkeys = [] # primary keys - ckeys = [] # clustering keys - qtypes = [] # field types + # add column types + pkeys = [] # primary keys + ckeys = [] # clustering keys + qtypes = [] # field types + def add_column(col): s = col.get_column_def() if col.primary_key: keys = (pkeys if col.partition_key else ckeys) keys.append('"{}"'.format(col.db_field_name)) qtypes.append(s) + for name, col in model._columns.items(): add_column(col) @@ -172,9 +177,9 @@ def get_create_table(model): with_qs = [] table_properties = ['bloom_filter_fp_chance', 'caching', 'comment', - 'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds', - 'index_interval', 'memtable_flush_period_in_ms', 'populate_io_cache_on_flush', - 'read_repair_chance', 'replicate_on_write'] + 'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds', + 'index_interval', 'memtable_flush_period_in_ms', 'populate_io_cache_on_flush', + 'read_repair_chance', 'replicate_on_write'] for prop_name in table_properties: prop_value = getattr(model, '__{}__'.format(prop_name), None) if prop_value is not None: @@ -211,9 +216,9 @@ def get_compaction_options(model): if not model.__compaction__: return {} - result = {'class':model.__compaction__} + result = {'class': model.__compaction__} - def setter(key, limited_to_strategy = None): + def setter(key, limited_to_strategy=None): """ sets key in result, checking if the key is limited to either SizeTiered or Leveled :param key: one of the compaction options, like "bucket_high" @@ -270,6 +275,7 @@ def get_table_settings(model): table = cluster.metadata.keyspaces[ks].tables[table] return table + def update_compaction(model): """Updates the compaction options for the given model if necessary. @@ -279,7 +285,7 @@ def update_compaction(model): `False` otherwise. :rtype: bool """ - logger.debug("Checking %s for compaction differences", model) + log.debug("Checking %s for compaction differences", model) table = get_table_settings(model) existing_options = table.options.copy() @@ -311,7 +317,7 @@ def update_compaction(model): options = json.dumps(options).replace('"', "'") cf_name = model.column_family_name() query = "ALTER TABLE {} with compaction = {}".format(cf_name, options) - logger.debug(query) + log.debug(query) execute(query) return True @@ -339,6 +345,3 @@ def drop_table(model): execute('drop table {};'.format(model.column_family_name(include_keyspace=True))) except KeyError: pass - - - diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py index 298e9d93..bbb4b964 100644 --- a/cassandra/cqlengine/models.py +++ b/cassandra/cqlengine/models.py @@ -1,4 +1,5 @@ import re +import six from cassandra.cqlengine import columns from cassandra.cqlengine.exceptions import ModelException, CQLEngineException, ValidationError @@ -7,16 +8,20 @@ from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned from cassandra.util import OrderedDict -class ModelDefinitionException(ModelException): pass + +class ModelDefinitionException(ModelException): + pass -class PolyMorphicModelException(ModelException): pass +class PolyMorphicModelException(ModelException): + pass -DEFAULT_KEYSPACE = None class UndefinedKeyspaceWarning(Warning): pass +DEFAULT_KEYSPACE = None + class hybrid_classmethod(object): """ @@ -106,7 +111,7 @@ class TTLDescriptor(object): """ def __get__(self, instance, model): if instance: - #instance = copy.deepcopy(instance) + # instance = copy.deepcopy(instance) # instance method def ttl_setter(ts): instance._ttl = ts @@ -124,6 +129,7 @@ class TTLDescriptor(object): def __call__(self, *args, **kwargs): raise NotImplementedError + class TimestampDescriptor(object): """ returns a query set descriptor with a timestamp specified @@ -138,10 +144,10 @@ class TimestampDescriptor(object): return model.objects.timestamp - def __call__(self, *args, **kwargs): raise NotImplementedError + class IfNotExistsDescriptor(object): """ return a query set descriptor with a if_not_exists flag specified @@ -159,13 +165,14 @@ class IfNotExistsDescriptor(object): def __call__(self, *args, **kwargs): raise NotImplementedError + class ConsistencyDescriptor(object): """ returns a query set descriptor if called on Class, instance if it was an instance call """ def __get__(self, instance, model): if instance: - #instance = copy.deepcopy(instance) + # instance = copy.deepcopy(instance) def consistency_setter(consistency): instance.__consistency__ = consistency return instance @@ -229,7 +236,7 @@ class ColumnDescriptor(object): """ try: return instance._values[self.column.column_name].getval() - except AttributeError as e: + except AttributeError: return self.query_evaluator def __set__(self, instance, value): @@ -258,9 +265,11 @@ class BaseModel(object): The base model class, don't inherit from this, inherit from Model, defined below """ - class DoesNotExist(_DoesNotExist): pass + class DoesNotExist(_DoesNotExist): + pass - class MultipleObjectsReturned(_MultipleObjectsReturned): pass + class MultipleObjectsReturned(_MultipleObjectsReturned): + pass objects = QuerySetDescriptor() ttl = TTLDescriptor() @@ -269,7 +278,7 @@ class BaseModel(object): # custom timestamps, see USING TIMESTAMP X timestamp = TimestampDescriptor() - + if_not_exists = IfNotExistsDescriptor() # _len is lazily created by __len__ @@ -302,8 +311,7 @@ class BaseModel(object): __queryset__ = ModelQuerySet __dmlquery__ = DMLQuery - __consistency__ = None # can be set per query - + __consistency__ = None # can be set per query # Additional table properties __bloom_filter_fp_chance__ = None @@ -318,9 +326,9 @@ class BaseModel(object): __read_repair_chance__ = None __replicate_on_write__ = None - _timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP) + _timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP) - _if_not_exists = False # optional if_not_exists flag to check existence before insertion + _if_not_exists = False # optional if_not_exists flag to check existence before insertion def __init__(self, **values): self._values = {} @@ -343,21 +351,19 @@ class BaseModel(object): self._batch = None self._timeout = NOT_SET - def __repr__(self): """ Pretty printing of models by their primary key """ return '{} <{}>'.format(self.__class__.__name__, - ', '.join(('{}={}'.format(k, getattr(self, k)) for k,v in six.iteritems(self._primary_keys))) + ', '.join(('{}={}'.format(k, getattr(self, k)) for k, v in six.iteritems(self._primary_keys))) ) - - @classmethod def _discover_polymorphic_submodels(cls): if not cls._is_polymorphic_base: raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes') + def _discover(klass): if not klass._is_polymorphic_base and klass.__polymorphic_key__ is not None: cls._polymorphic_map[klass.__polymorphic_key__] = klass @@ -381,7 +387,7 @@ class BaseModel(object): # and translate that into our local fields # the db_map is a db_field -> model field map items = values.items() - field_dict = dict([(cls._db_map.get(k, k),v) for k,v in items]) + field_dict = dict([(cls._db_map.get(k, k), v) for k, v in items]) if cls._is_polymorphic: poly_key = field_dict.get(cls._polymorphic_column_name) @@ -421,8 +427,9 @@ class BaseModel(object): :return: """ - if not self._is_persisted: return False - pks = self._primary_keys.keys() + if not self._is_persisted: + return False + return all([not self._values[k].changed for k in self._primary_keys]) @classmethod @@ -481,11 +488,14 @@ class BaseModel(object): ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s) cf_name += ccase(cls.__name__) - #trim to less than 48 characters or cassandra will complain + # trim to less than 48 characters or cassandra will complain cf_name = cf_name[-48:] cf_name = cf_name.lower() cf_name = re.sub(r'^_+', '', cf_name) - if not include_keyspace: return cf_name + + if not include_keyspace: + return cf_name + return '{}.{}'.format(cls._get_keyspace(), cf_name) def validate(self): @@ -499,8 +509,7 @@ class BaseModel(object): val = col.validate(v) setattr(self, name, val) - ### Let an instance be used like a dict of its columns keys/values - + # Let an instance be used like a dict of its columns keys/values def __iter__(self): """ Iterate over column ids. """ for column_id in self._columns.keys(): @@ -620,7 +629,6 @@ class BaseModel(object): else: setattr(self, self._polymorphic_column_name, self.__polymorphic_key__) - is_new = self.pk is None self.validate() self.__dmlquery__(self.__class__, self, batch=self._batch, @@ -631,7 +639,7 @@ class BaseModel(object): transaction=self._transaction, timeout=self._timeout).save() - #reset the value managers + # reset the value managers for v in self._values.values(): v.reset_previous_value() self._is_persisted = True @@ -680,7 +688,7 @@ class BaseModel(object): transaction=self._transaction, timeout=self._timeout).update() - #reset the value managers + # reset the value managers for v in self._values.values(): v.reset_previous_value() self._is_persisted = True @@ -704,7 +712,7 @@ class BaseModel(object): """ Returns a list of the columns that have been updated since instantiation or save """ - return [k for k,v in self._values.items() if v.changed] + return [k for k, v in self._values.items() if v.changed] @classmethod def _class_batch(cls, batch): @@ -715,29 +723,28 @@ class BaseModel(object): self._batch = batch return self - batch = hybrid_classmethod(_class_batch, _inst_batch) class ModelMetaClass(type): def __new__(cls, name, bases, attrs): - #move column definitions into columns dict - #and set default column names + # move column definitions into columns dict + # and set default column names column_dict = OrderedDict() primary_keys = OrderedDict() pk_name = None - #get inherited properties + # get inherited properties inherited_columns = OrderedDict() for base in bases: - for k,v in getattr(base, '_defined_columns', {}).items(): - inherited_columns.setdefault(k,v) + for k, v in getattr(base, '_defined_columns', {}).items(): + inherited_columns.setdefault(k, v) - #short circuit __abstract__ inheritance + # short circuit __abstract__ inheritance is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False) - #short circuit __polymorphic_key__ inheritance + # short circuit __polymorphic_key__ inheritance attrs['__polymorphic_key__'] = attrs.get('__polymorphic_key__', None) def _transform_column(col_name, col_obj): @@ -745,11 +752,10 @@ class ModelMetaClass(type): if col_obj.primary_key: primary_keys[col_name] = col_obj col_obj.set_column_name(col_name) - #set properties + # set properties attrs[col_name] = ColumnDescriptor(col_obj) - column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.Column)] - #column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position)) + column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] column_definitions = sorted(column_definitions, key=lambda x: x[1].position) is_polymorphic_base = any([c[1].polymorphic_key for c in column_definitions]) @@ -780,7 +786,7 @@ class ModelMetaClass(type): defined_columns = OrderedDict(column_definitions) # check for primary key - if not is_abstract and not any([v.primary_key for k,v in column_definitions]): + if not is_abstract and not any([v.primary_key for k, v in column_definitions]): raise ModelDefinitionException("At least 1 primary key is required.") counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)] @@ -790,7 +796,7 @@ class ModelMetaClass(type): has_partition_keys = any(v.partition_key for (k, v) in column_definitions) - #transform column definitions + # transform column definitions for k, v in column_definitions: # don't allow a column with the same name as a built-in attribute or method if k in BaseModel.__dict__: @@ -810,7 +816,7 @@ class ModelMetaClass(type): partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key) - #setup partition key shortcut + # setup partition key shortcut if len(partition_keys) == 0: if not is_abstract: raise ModelException("at least one partition key must be defined") @@ -835,12 +841,12 @@ class ModelMetaClass(type): raise ModelException("invalid clustering order {} for column {}".format(repr(v.clustering_order), v.db_field_name)) col_names.add(v.db_field_name) - #create db_name -> model name map for loading + # create db_name -> model name map for loading db_map = {} for field_name, col in column_dict.items(): db_map[col.db_field_name] = field_name - #add management members to the class + # add management members to the class attrs['_columns'] = column_dict attrs['_primary_keys'] = primary_keys attrs['_defined_columns'] = defined_columns @@ -862,29 +868,31 @@ class ModelMetaClass(type): attrs['_polymorphic_column_name'] = polymorphic_column_name attrs['_polymorphic_map'] = {} if is_polymorphic_base else None - #setup class exceptions + # setup class exceptions DoesNotExistBase = None for base in bases: DoesNotExistBase = getattr(base, 'DoesNotExist', None) - if DoesNotExistBase is not None: break + if DoesNotExistBase is not None: + break + DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist) attrs['DoesNotExist'] = type('DoesNotExist', (DoesNotExistBase,), {}) MultipleObjectsReturnedBase = None for base in bases: MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None) - if MultipleObjectsReturnedBase is not None: break + if MultipleObjectsReturnedBase is not None: + break + MultipleObjectsReturnedBase = DoesNotExistBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned) attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {}) - #create the class and add a QuerySet to it + # create the class and add a QuerySet to it klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs) return klass -import six - @six.add_metaclass(ModelMetaClass) class Model(BaseModel): __abstract__ = True @@ -906,7 +914,7 @@ class Model(BaseModel): __default_ttl__ = None """ *Optional* The default ttl used by this model. - + This can be overridden by using the :meth:`~.ttl` method. """ diff --git a/cassandra/cqlengine/named.py b/cassandra/cqlengine/named.py index 0f0b86b2..108f8e63 100644 --- a/cassandra/cqlengine/named.py +++ b/cassandra/cqlengine/named.py @@ -3,6 +3,7 @@ from cassandra.cqlengine.query import AbstractQueryableColumn, SimpleQuerySet from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned + class QuerySetDescriptor(object): """ returns a fresh queryset for the given model @@ -63,8 +64,11 @@ class NamedTable(object): objects = QuerySetDescriptor() - class DoesNotExist(_DoesNotExist): pass - class MultipleObjectsReturned(_MultipleObjectsReturned): pass + class DoesNotExist(_DoesNotExist): + pass + + class MultipleObjectsReturned(_MultipleObjectsReturned): + pass def __init__(self, keyspace, name): self.keyspace = keyspace @@ -118,4 +122,3 @@ class NamedKeyspace(object): name that belongs to this keyspace """ return NamedTable(self.name, name) - diff --git a/cassandra/cqlengine/operators.py b/cassandra/cqlengine/operators.py index e609531f..c8f02e14 100644 --- a/cassandra/cqlengine/operators.py +++ b/cassandra/cqlengine/operators.py @@ -2,7 +2,9 @@ import six import sys -class QueryOperatorException(Exception): pass +class QueryOperatorException(Exception): + pass + # move to central spot class UnicodeMixin(object): @@ -31,12 +33,14 @@ class BaseQueryOperator(UnicodeMixin): raise QueryOperatorException("get_operator can only be called from a BaseQueryOperator subclass") if not hasattr(cls, 'opmap'): cls.opmap = {} + def _recurse(klass): if klass.symbol: cls.opmap[klass.symbol.upper()] = klass for subklass in klass.__subclasses__(): _recurse(subklass) pass + _recurse(cls) try: return cls.opmap[symbol.upper()] diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index 2dcef018..b46a57c1 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -1,6 +1,7 @@ import copy from datetime import datetime, timedelta import time +import six from cassandra.cqlengine import columns from cassandra.cqlengine.connection import execute, NOT_SET @@ -17,11 +18,16 @@ from cassandra.cqlengine.statements import (WhereClause, SelectStatement, Delete TransactionClause) -class QueryException(CQLEngineException): pass -class DoesNotExist(QueryException): pass -class MultipleObjectsReturned(QueryException): pass +class QueryException(CQLEngineException): + pass -import six + +class DoesNotExist(QueryException): + pass + + +class MultipleObjectsReturned(QueryException): + pass def check_applied(result): @@ -30,7 +36,7 @@ def check_applied(result): if that value is false, it means our light-weight transaction didn't applied to database. """ - if result and '[applied]' in result[0] and result[0]['[applied]'] == False: + if result and '[applied]' in result[0] and not result[0]['[applied]']: raise LWTException('') @@ -77,8 +83,8 @@ class AbstractQueryableColumn(UnicodeMixin): class BatchType(object): - Unlogged = 'UNLOGGED' - Counter = 'COUNTER' + Unlogged = 'UNLOGGED' + Counter = 'COUNTER' class BatchQuery(object): @@ -194,8 +200,9 @@ class BatchQuery(object): return self def __exit__(self, exc_type, exc_val, exc_tb): - #don't execute if there was an exception by default - if exc_type is not None and not self._execute_on_exception: return + # don't execute if there was an exception by default + if exc_type is not None and not self._execute_on_exception: + return self.execute() @@ -205,29 +212,29 @@ class AbstractQuerySet(object): super(AbstractQuerySet, self).__init__() self.model = model - #Where clause filters + # Where clause filters self._where = [] # Transaction clause filters self._transaction = [] - #ordering arguments + # ordering arguments self._order = [] self._allow_filtering = False - #CQL has a default limit of 10000, it's defined here - #because explicit is better than implicit + # CQL has a default limit of 10000, it's defined here + # because explicit is better than implicit self._limit = 10000 - #see the defer and only methods + # see the defer and only methods self._defer_fields = [] self._only_fields = [] self._values_list = False self._flat_values_list = False - #results cache + # results cache self._con = None self._cur = None self._result_cache = None @@ -265,7 +272,7 @@ class AbstractQuerySet(object): def __deepcopy__(self, memo): clone = self.__class__(self.model) for k, v in self.__dict__.items(): - if k in ['_con', '_cur', '_result_cache', '_result_idx']: # don't clone these + if k in ['_con', '_cur', '_result_cache', '_result_idx']: # don't clone these clone.__dict__[k] = None elif k == '_batch': # we need to keep the same batch instance across @@ -284,7 +291,7 @@ class AbstractQuerySet(object): self._execute_query() return len(self._result_cache) - #----query generation / execution---- + # ----query generation / execution---- def _select_fields(self): """ returns the fields to select """ @@ -308,7 +315,7 @@ class AbstractQuerySet(object): allow_filtering=self._allow_filtering ) - #----Reads------ + # ----Reads------ def _execute_query(self): if self._batch: @@ -330,7 +337,7 @@ class AbstractQuerySet(object): self._result_idx += 1 self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx]) - #return the connection to the connection pool if we have all objects + # return the connection to the connection pool if we have all objects if self._result_cache and self._result_idx == (len(self._result_cache) - 1): self._con = None self._cur = None @@ -350,7 +357,7 @@ class AbstractQuerySet(object): num_results = len(self._result_cache) if isinstance(s, slice): - #calculate the amount of results that need to be loaded + # calculate the amount of results that need to be loaded end = num_results if s.step is None else s.step if end < 0: end += num_results @@ -359,11 +366,12 @@ class AbstractQuerySet(object): self._fill_result_cache_to_idx(end) return self._result_cache[s.start:s.stop:s.step] else: - #return the object at this index + # return the object at this index s = int(s) - #handle negative indexing - if s < 0: s += num_results + # handle negative indexing + if s < 0: + s += num_results if s >= num_results: raise IndexError @@ -453,7 +461,6 @@ class AbstractQuerySet(object): if not isinstance(val, Token): raise QueryException("Virtual column 'pk__token' may only be compared to Token() values") column = columns._PartitionKeysToken(self.model) - quote_field = False else: raise QueryException("Can't resolve column name: '{}'".format(col_name)) @@ -484,7 +491,7 @@ class AbstractQuerySet(object): Returns a QuerySet filtered on the keyword arguments """ - #add arguments to the where clause filters + # add arguments to the where clause filters if len([x for x in kwargs.values() if x is None]): raise CQLEngineException("None values on filter are not allowed") @@ -497,7 +504,7 @@ class AbstractQuerySet(object): for arg, val in kwargs.items(): col_name, col_op = self._parse_filter_arg(arg) quote_field = True - #resolve column and operator + # resolve column and operator try: column = self.model._get_column(col_name) except KeyError: @@ -519,7 +526,7 @@ class AbstractQuerySet(object): len(val.value), len(partition_columns))) val.set_columns(partition_columns) - #get query operator, or use equals if not supplied + # get query operator, or use equals if not supplied operator_class = BaseWhereOperator.get_operator(col_op or 'EQ') operator = operator_class() @@ -559,8 +566,7 @@ class AbstractQuerySet(object): if len(self._result_cache) == 0: raise self.model.DoesNotExist elif len(self._result_cache) > 1: - raise self.model.MultipleObjectsReturned( - '{} objects found'.format(len(self._result_cache))) + raise self.model.MultipleObjectsReturned('{} objects found'.format(len(self._result_cache))) else: return self[0] @@ -665,7 +671,7 @@ class AbstractQuerySet(object): if clone._defer_fields or clone._only_fields: raise QueryException("QuerySet alread has only or defer fields defined") - #check for strange fields + # check for strange fields missing_fields = [f for f in fields if f not in self.model._columns.keys()] if missing_fields: raise QueryException( @@ -698,7 +704,7 @@ class AbstractQuerySet(object): """ Deletes the contents of a query """ - #validate where clause + # validate where clause partition_key = [x for x in self.model._primary_keys.values()][0] if not any([c.field == partition_key.column_name for c in self._where]): raise QueryException("The partition key must be defined on delete queries") @@ -759,14 +765,14 @@ class ModelQuerySet(AbstractQuerySet): """ def _validate_select_where(self): """ Checks that a filterset will not create invalid select statement """ - #check that there's either a = or IN relationship with a primary key or indexed field + # check that there's either a = or IN relationship with a primary key or indexed field equal_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)] token_comparison = any([w for w in self._where if isinstance(w.value, Token)]) if not any([w.primary_key or w.index for w in equal_ops]) and not token_comparison and not self._allow_filtering: raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field') if not self._allow_filtering: - #if the query is not on an indexed field + # if the query is not on an indexed field if not any([w.index for w in equal_ops]): if not any([w.partition_key for w in equal_ops]) and not token_comparison: raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset') @@ -783,9 +789,9 @@ class ModelQuerySet(AbstractQuerySet): def _get_result_constructor(self): """ Returns a function that will be used to instantiate query results """ - if not self._values_list: # we want models + if not self._values_list: # we want models return lambda rows: self.model._construct_instance(rows) - elif self._flat_values_list: # the user has requested flattened list (1 value per row) + elif self._flat_values_list: # the user has requested flattened list (1 value per row) return lambda row: row.popitem()[1] else: return lambda row: self._get_row_value_list(self._only_fields, row) @@ -803,7 +809,7 @@ class ModelQuerySet(AbstractQuerySet): if column is None: raise QueryException("Can't resolve the column name: '{}'".format(colname)) - #validate the column selection + # validate the column selection if not column.primary_key: raise QueryException( "Can't order on '{}', can only order on (clustered) primary keys".format(colname)) @@ -958,7 +964,7 @@ class ModelQuerySet(AbstractQuerySet): # we should not provide default values in this use case. val = col.validate(val) - + if val is None: nulled_columns.add(col_name) continue @@ -1072,11 +1078,11 @@ class DMLQuery(object): timestamp=self._timestamp, transactions=self._transaction) for name, col in self.instance._clustering_keys.items(): null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) - #get defined fields and their column names + # get defined fields and their column names for name, col in self.model._columns.items(): # if clustering key is null, don't include non static columns if null_clustering_key and not col.static and not col.partition_key: - continue + continue if not col.is_primary_key: val = getattr(self.instance, name, None) val_mgr = self.instance._values[name] @@ -1088,19 +1094,24 @@ class DMLQuery(object): # don't update something if it hasn't changed if not val_mgr.changed and not isinstance(col, columns.Counter): continue - + static_changed_only = static_changed_only and col.static if isinstance(col, (columns.BaseContainerColumn, columns.Counter)): # get appropriate clause - if isinstance(col, columns.List): klass = ListUpdateClause - elif isinstance(col, columns.Map): klass = MapUpdateClause - elif isinstance(col, columns.Set): klass = SetUpdateClause - elif isinstance(col, columns.Counter): klass = CounterUpdateClause - else: raise RuntimeError + if isinstance(col, columns.List): + klass = ListUpdateClause + elif isinstance(col, columns.Map): + klass = MapUpdateClause + elif isinstance(col, columns.Set): + klass = SetUpdateClause + elif isinstance(col, columns.Counter): + klass = CounterUpdateClause + else: + raise RuntimeError # do the stuff clause = klass(col.db_field_name, val, - previous=val_mgr.previous_value, column=col) + previous=val_mgr.previous_value, column=col) if clause.get_context_size() > 0: statement.add_assignment_clause(clause) else: @@ -1145,7 +1156,7 @@ class DMLQuery(object): static_save_only = static_save_only and col._val_is_null(getattr(self.instance, name, None)) for name, col in self.instance._columns.items(): if static_save_only and not col.static and not col.partition_key: - continue + continue val = getattr(self.instance, name, None) if col._val_is_null(val): if self.instance._values[name].changed: @@ -1171,7 +1182,9 @@ class DMLQuery(object): ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp) for name, col in self.model._primary_keys.items(): - if (not col.partition_key) and (getattr(self.instance, name) is None): continue + if (not col.partition_key) and (getattr(self.instance, name) is None): + continue + ds.add_where_clause(WhereClause( col.db_field_name, EqualsOperator(), diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index 917b9862..ec59c470 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -7,7 +7,8 @@ from cassandra.cqlengine.functions import QueryValue from cassandra.cqlengine.operators import BaseWhereOperator, InOperator -class StatementException(Exception): pass +class StatementException(Exception): + pass class UnicodeMixin(object): @@ -16,6 +17,7 @@ class UnicodeMixin(object): else: __str__ = lambda x: six.text_type(x).encode('utf-8') + class ValueQuoter(UnicodeMixin): def __init__(self, value): @@ -28,7 +30,7 @@ class ValueQuoter(UnicodeMixin): elif isinstance(self.value, (list, tuple)): return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']' elif isinstance(self.value, dict): - return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k,v in self.value.items()]) + '}' + return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}' elif isinstance(self.value, set): return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}' return cql_quote(self.value) @@ -215,7 +217,8 @@ class SetUpdateClause(ContainerUpdateClause): self._analyzed = True def get_context_size(self): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() if (self.previous is None and not self._assignments and self._additions is None and @@ -224,7 +227,8 @@ class SetUpdateClause(ContainerUpdateClause): return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals)) def update_context(self, ctx): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() ctx_id = self.context_id if (self.previous is None and self._assignments is None and @@ -250,7 +254,8 @@ class ListUpdateClause(ContainerUpdateClause): self._prepend = None def __unicode__(self): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() qs = [] ctx_id = self.context_id if self._assignments is not None: @@ -267,11 +272,13 @@ class ListUpdateClause(ContainerUpdateClause): return ', '.join(qs) def get_context_size(self): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend)) def update_context(self, ctx): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() ctx_id = self.context_id if self._assignments is not None: ctx[str(ctx_id)] = self._to_database(self._assignments) @@ -313,13 +320,13 @@ class ListUpdateClause(ContainerUpdateClause): else: # the max start idx we want to compare - search_space = len(self.value) - max(0, len(self.previous)-1) + search_space = len(self.value) - max(0, len(self.previous) - 1) # the size of the sub lists we want to look at search_size = len(self.previous) for i in range(search_space): - #slice boundary + # slice boundary j = i + search_size sub = self.value[i:j] idx_cmp = lambda idx: self.previous[idx] == sub[idx] @@ -354,13 +361,15 @@ class MapUpdateClause(ContainerUpdateClause): self._analyzed = True def get_context_size(self): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() if self.previous is None and not self._updates: return 1 return len(self._updates or []) * 2 def update_context(self, ctx): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() ctx_id = self.context_id if self.previous is None and not self._updates: ctx[str(ctx_id)] = {} @@ -372,7 +381,8 @@ class MapUpdateClause(ContainerUpdateClause): ctx_id += 2 def __unicode__(self): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() qs = [] ctx_id = self.context_id @@ -439,16 +449,19 @@ class MapDeleteClause(BaseDeleteClause): self._analyzed = True def update_context(self, ctx): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() for idx, key in enumerate(self._removals): ctx[str(self.context_id + idx)] = key def get_context_size(self): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() return len(self._removals) def __unicode__(self): - if not self._analyzed: self._analyze() + if not self._analyzed: + self._analyze() return ', '.join(['"{}"[%({})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))]) @@ -521,7 +534,6 @@ class BaseCQLStatement(UnicodeMixin): def __unicode__(self): raise NotImplementedError - def __repr__(self): return self.__unicode__() @@ -645,13 +657,12 @@ class InsertStatement(AssignmentStatement): ttl=None, timestamp=None, if_not_exists=False): - super(InsertStatement, self).__init__( - table, - assignments=assignments, - consistency=consistency, - where=where, - ttl=ttl, - timestamp=timestamp) + super(InsertStatement, self).__init__(table, + assignments=assignments, + consistency=consistency, + where=where, + ttl=ttl, + timestamp=timestamp) self.if_not_exists = if_not_exists @@ -813,4 +824,3 @@ class DeleteStatement(BaseCQLStatement): qs += [self._where] return ' '.join(qs) -