1040 lines
		
	
	
		
			34 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1040 lines
		
	
	
		
			34 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import copy
 | 
						|
from datetime import datetime
 | 
						|
from uuid import uuid4
 | 
						|
from hashlib import md5
 | 
						|
from time import time
 | 
						|
from uuid import uuid1
 | 
						|
from cqlengine import BaseContainerColumn, BaseValueManager, Map, columns
 | 
						|
from cqlengine.columns import Counter
 | 
						|
 | 
						|
from cqlengine.connection import connection_manager, execute, RowResult
 | 
						|
 | 
						|
from cqlengine.exceptions import CQLEngineException, ValidationError
 | 
						|
from cqlengine.functions import QueryValue, Token
 | 
						|
 | 
						|
#CQL 3 reference:
 | 
						|
#http://www.datastax.com/docs/1.1/references/cql/index
 | 
						|
 | 
						|
class QueryException(CQLEngineException): pass
 | 
						|
class DoesNotExist(QueryException): pass
 | 
						|
class MultipleObjectsReturned(QueryException): pass
 | 
						|
 | 
						|
 | 
						|
class QueryOperatorException(QueryException): pass
 | 
						|
 | 
						|
 | 
						|
class QueryOperator(object):
 | 
						|
    # The symbol that identifies this operator in filter kwargs
 | 
						|
    # ie: colname__<symbol>
 | 
						|
    symbol = None
 | 
						|
 | 
						|
    # The comparator symbol this operator uses in cql
 | 
						|
    cql_symbol = None
 | 
						|
 | 
						|
    QUERY_VALUE_WRAPPER = QueryValue
 | 
						|
 | 
						|
    def __init__(self, column, value):
 | 
						|
        self.column = column
 | 
						|
        self.value = value
 | 
						|
 | 
						|
        if isinstance(value, QueryValue):
 | 
						|
            self.query_value = value
 | 
						|
        else:
 | 
						|
            self.query_value = self.QUERY_VALUE_WRAPPER(value)
 | 
						|
 | 
						|
        #perform validation on this operator
 | 
						|
        self.validate_operator()
 | 
						|
        self.validate_value()
 | 
						|
 | 
						|
    @property
 | 
						|
    def cql(self):
 | 
						|
        """
 | 
						|
        Returns this operator's portion of the WHERE clause
 | 
						|
        """
 | 
						|
        return '{} {} {}'.format(self.column.cql, self.cql_symbol, self.query_value.cql)
 | 
						|
 | 
						|
    def validate_operator(self):
 | 
						|
        """
 | 
						|
        Checks that this operator can be used on the column provided
 | 
						|
        """
 | 
						|
        if self.symbol is None:
 | 
						|
            raise QueryOperatorException(
 | 
						|
                    "{} is not a valid operator, use one with 'symbol' defined".format(
 | 
						|
                        self.__class__.__name__
 | 
						|
                    )
 | 
						|
                )
 | 
						|
        if self.cql_symbol is None:
 | 
						|
            raise QueryOperatorException(
 | 
						|
                    "{} is not a valid operator, use one with 'cql_symbol' defined".format(
 | 
						|
                        self.__class__.__name__
 | 
						|
                    )
 | 
						|
                )
 | 
						|
 | 
						|
    def validate_value(self):
 | 
						|
        """
 | 
						|
        Checks that the compare value works with this operator
 | 
						|
 | 
						|
        Doesn't do anything by default
 | 
						|
        """
 | 
						|
        pass
 | 
						|
 | 
						|
    def get_dict(self):
 | 
						|
        """
 | 
						|
        Returns this operators contribution to the cql.query arg dictionanry
 | 
						|
 | 
						|
        ie: if this column's name is colname, and the identifier is colval,
 | 
						|
        this should return the dict: {'colval':<self.value>}
 | 
						|
        SELECT * FROM column_family WHERE colname=:colval
 | 
						|
        """
 | 
						|
        return self.query_value.get_dict(self.column)
 | 
						|
 | 
						|
    @classmethod
 | 
						|
    def get_operator(cls, symbol):
 | 
						|
        if not hasattr(cls, 'opmap'):
 | 
						|
            QueryOperator.opmap = {}
 | 
						|
            def _recurse(klass):
 | 
						|
                if klass.symbol:
 | 
						|
                    QueryOperator.opmap[klass.symbol.upper()] = klass
 | 
						|
                for subklass in klass.__subclasses__():
 | 
						|
                    _recurse(subklass)
 | 
						|
                pass
 | 
						|
            _recurse(QueryOperator)
 | 
						|
        try:
 | 
						|
            return QueryOperator.opmap[symbol.upper()]
 | 
						|
        except KeyError:
 | 
						|
            raise QueryOperatorException("{} doesn't map to a QueryOperator".format(symbol))
 | 
						|
 | 
						|
    # equality operator, used by tests
 | 
						|
 | 
						|
    def __eq__(self, op):
 | 
						|
        return self.__class__ is op.__class__ and \
 | 
						|
                self.column.db_field_name == op.column.db_field_name and \
 | 
						|
                self.value == op.value
 | 
						|
 | 
						|
    def __ne__(self, op):
 | 
						|
        return not (self == op)
 | 
						|
 | 
						|
    def __hash__(self):
 | 
						|
        return hash(self.column.db_field_name) ^ hash(self.value)
 | 
						|
 | 
						|
 | 
						|
class EqualsOperator(QueryOperator):
 | 
						|
    symbol = 'EQ'
 | 
						|
    cql_symbol = '='
 | 
						|
 | 
						|
 | 
						|
class IterableQueryValue(QueryValue):
 | 
						|
    def __init__(self, value):
 | 
						|
        try:
 | 
						|
            super(IterableQueryValue, self).__init__(value, [uuid4().hex for i in value])
 | 
						|
        except TypeError:
 | 
						|
            raise QueryException("in operator arguments must be iterable, {} found".format(value))
 | 
						|
 | 
						|
    def get_dict(self, column):
 | 
						|
        return dict((i, column.to_database(v)) for (i, v) in zip(self.identifier, self.value))
 | 
						|
 | 
						|
    def get_cql(self):
 | 
						|
        return '({})'.format(', '.join(':{}'.format(i) for i in self.identifier))
 | 
						|
 | 
						|
 | 
						|
class InOperator(EqualsOperator):
 | 
						|
    symbol = 'IN'
 | 
						|
    cql_symbol = 'IN'
 | 
						|
 | 
						|
    QUERY_VALUE_WRAPPER = IterableQueryValue
 | 
						|
 | 
						|
 | 
						|
class GreaterThanOperator(QueryOperator):
 | 
						|
    symbol = "GT"
 | 
						|
    cql_symbol = '>'
 | 
						|
 | 
						|
 | 
						|
class GreaterThanOrEqualOperator(QueryOperator):
 | 
						|
    symbol = "GTE"
 | 
						|
    cql_symbol = '>='
 | 
						|
 | 
						|
 | 
						|
class LessThanOperator(QueryOperator):
 | 
						|
    symbol = "LT"
 | 
						|
    cql_symbol = '<'
 | 
						|
 | 
						|
 | 
						|
class LessThanOrEqualOperator(QueryOperator):
 | 
						|
    symbol = "LTE"
 | 
						|
    cql_symbol = '<='
 | 
						|
 | 
						|
 | 
						|
class AbstractQueryableColumn(object):
 | 
						|
    """
 | 
						|
    exposes cql query operators through pythons
 | 
						|
    builtin comparator symbols
 | 
						|
    """
 | 
						|
 | 
						|
    def _get_column(self):
 | 
						|
        raise NotImplementedError
 | 
						|
 | 
						|
    def in_(self, item):
 | 
						|
        """
 | 
						|
        Returns an in operator
 | 
						|
 | 
						|
        used in where you'd typically want to use python's `in` operator
 | 
						|
        """
 | 
						|
        return InOperator(self._get_column(), item)
 | 
						|
 | 
						|
    def __eq__(self, other):
 | 
						|
        return EqualsOperator(self._get_column(), other)
 | 
						|
 | 
						|
    def __gt__(self, other):
 | 
						|
        return GreaterThanOperator(self._get_column(), other)
 | 
						|
 | 
						|
    def __ge__(self, other):
 | 
						|
        return GreaterThanOrEqualOperator(self._get_column(), other)
 | 
						|
 | 
						|
    def __lt__(self, other):
 | 
						|
        return LessThanOperator(self._get_column(), other)
 | 
						|
 | 
						|
    def __le__(self, other):
 | 
						|
        return LessThanOrEqualOperator(self._get_column(), other)
 | 
						|
 | 
						|
 | 
						|
class BatchType(object):
 | 
						|
    Unlogged    = 'UNLOGGED'
 | 
						|
    Counter     = 'COUNTER'
 | 
						|
 | 
						|
 | 
						|
class BatchQuery(object):
 | 
						|
    """
 | 
						|
    Handles the batching of queries
 | 
						|
 | 
						|
    http://www.datastax.com/docs/1.2/cql_cli/cql/BATCH
 | 
						|
    """
 | 
						|
    _consistency = None
 | 
						|
 | 
						|
    def __init__(self, batch_type=None, timestamp=None, consistency=None):
 | 
						|
        self.queries = []
 | 
						|
        self.batch_type = batch_type
 | 
						|
        if timestamp is not None and not isinstance(timestamp, datetime):
 | 
						|
            raise CQLEngineException('timestamp object must be an instance of datetime')
 | 
						|
        self.timestamp = timestamp
 | 
						|
        self._consistency = consistency
 | 
						|
 | 
						|
    def add_query(self, query, params):
 | 
						|
        self.queries.append((query, params))
 | 
						|
 | 
						|
    def consistency(self, consistency):
 | 
						|
        self._consistency = consistency
 | 
						|
 | 
						|
    def execute(self):
 | 
						|
        if len(self.queries) == 0:
 | 
						|
            # Empty batch is a no-op
 | 
						|
            return
 | 
						|
 | 
						|
        opener = 'BEGIN ' + (self.batch_type + ' ' if self.batch_type else '') + ' BATCH'
 | 
						|
        if self.timestamp:
 | 
						|
            epoch = datetime(1970, 1, 1)
 | 
						|
            ts = long((self.timestamp - epoch).total_seconds() * 1000)
 | 
						|
            opener += ' USING TIMESTAMP {}'.format(ts)
 | 
						|
 | 
						|
        query_list = [opener]
 | 
						|
        parameters = {}
 | 
						|
        for query, params in self.queries:
 | 
						|
            query_list.append('  ' + query)
 | 
						|
            parameters.update(params)
 | 
						|
 | 
						|
        query_list.append('APPLY BATCH;')
 | 
						|
 | 
						|
        execute('\n'.join(query_list), parameters, self._consistency)
 | 
						|
 | 
						|
        self.queries = []
 | 
						|
 | 
						|
    def __enter__(self):
 | 
						|
        return self
 | 
						|
 | 
						|
    def __exit__(self, exc_type, exc_val, exc_tb):
 | 
						|
        #don't execute if there was an exception
 | 
						|
        if exc_type is not None: return
 | 
						|
        self.execute()
 | 
						|
 | 
						|
 | 
						|
class AbstractQuerySet(object):
 | 
						|
 | 
						|
    def __init__(self, model):
 | 
						|
        super(AbstractQuerySet, self).__init__()
 | 
						|
        self.model = model
 | 
						|
 | 
						|
        #Where clause filters
 | 
						|
        self._where = []
 | 
						|
 | 
						|
        #ordering arguments
 | 
						|
        self._order = []
 | 
						|
 | 
						|
        self._allow_filtering = False
 | 
						|
 | 
						|
        #CQL has a default limit of 10000, it's defined here
 | 
						|
        #because explicit is better than implicit
 | 
						|
        self._limit = 10000
 | 
						|
 | 
						|
        #see the defer and only methods
 | 
						|
        self._defer_fields = []
 | 
						|
        self._only_fields = []
 | 
						|
 | 
						|
        self._values_list = False
 | 
						|
        self._flat_values_list = False
 | 
						|
 | 
						|
        #results cache
 | 
						|
        self._con = None
 | 
						|
        self._cur = None
 | 
						|
        self._result_cache = None
 | 
						|
        self._result_idx = None
 | 
						|
 | 
						|
        self._batch = None
 | 
						|
        self._ttl = None
 | 
						|
        self._consistency = None
 | 
						|
 | 
						|
    @property
 | 
						|
    def column_family_name(self):
 | 
						|
        return self.model.column_family_name()
 | 
						|
 | 
						|
    def __unicode__(self):
 | 
						|
        return self._select_query()
 | 
						|
 | 
						|
    def __str__(self):
 | 
						|
        return str(self.__unicode__())
 | 
						|
 | 
						|
    def __call__(self, *args, **kwargs):
 | 
						|
        return self.filter(*args, **kwargs)
 | 
						|
 | 
						|
    def __deepcopy__(self, memo):
 | 
						|
        clone = self.__class__(self.model)
 | 
						|
        for k,v in self.__dict__.items():
 | 
						|
            if k in ['_con', '_cur', '_result_cache', '_result_idx']:
 | 
						|
                clone.__dict__[k] = None
 | 
						|
            elif k == '_batch':
 | 
						|
                # we need to keep the same batch instance across
 | 
						|
                # all queryset clones, otherwise the batched queries
 | 
						|
                # fly off into other batch instances which are never
 | 
						|
                # executed, thx @dokai
 | 
						|
                clone.__dict__[k] = self._batch
 | 
						|
            else:
 | 
						|
                clone.__dict__[k] = copy.deepcopy(v, memo)
 | 
						|
 | 
						|
        return clone
 | 
						|
 | 
						|
    def __len__(self):
 | 
						|
        self._execute_query()
 | 
						|
        return len(self._result_cache)
 | 
						|
 | 
						|
    #----query generation / execution----
 | 
						|
 | 
						|
    def _where_clause(self):
 | 
						|
        """ Returns a where clause based on the given filter args """
 | 
						|
        return ' AND '.join([f.cql for f in self._where])
 | 
						|
 | 
						|
    def _where_values(self):
 | 
						|
        """ Returns the value dict to be passed to the cql query """
 | 
						|
        values = {}
 | 
						|
        for where in self._where:
 | 
						|
            values.update(where.get_dict())
 | 
						|
        return values
 | 
						|
 | 
						|
    def _get_select_statement(self):
 | 
						|
        """ returns the select portion of this queryset's cql statement """
 | 
						|
        raise NotImplementedError
 | 
						|
 | 
						|
    def _select_query(self):
 | 
						|
        """
 | 
						|
        Returns a select clause based on the given filter args
 | 
						|
        """
 | 
						|
        qs = [self._get_select_statement()]
 | 
						|
        qs += ['FROM {}'.format(self.column_family_name)]
 | 
						|
 | 
						|
        if self._where:
 | 
						|
            qs += ['WHERE {}'.format(self._where_clause())]
 | 
						|
 | 
						|
        if self._order:
 | 
						|
            qs += ['ORDER BY {}'.format(', '.join(self._order))]
 | 
						|
 | 
						|
        if self._limit:
 | 
						|
            qs += ['LIMIT {}'.format(self._limit)]
 | 
						|
 | 
						|
        if self._allow_filtering:
 | 
						|
            qs += ['ALLOW FILTERING']
 | 
						|
 | 
						|
        return ' '.join(qs)
 | 
						|
 | 
						|
    #----Reads------
 | 
						|
 | 
						|
    def _execute_query(self):
 | 
						|
        if self._batch:
 | 
						|
            raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
 | 
						|
        if self._result_cache is None:
 | 
						|
            columns, self._result_cache = execute(self._select_query(), self._where_values(), self._consistency)
 | 
						|
            self._construct_result = self._get_result_constructor(columns)
 | 
						|
 | 
						|
    def _fill_result_cache_to_idx(self, idx):
 | 
						|
        self._execute_query()
 | 
						|
        if self._result_idx is None:
 | 
						|
            self._result_idx = -1
 | 
						|
 | 
						|
        qty = idx - self._result_idx
 | 
						|
        if qty < 1:
 | 
						|
            return
 | 
						|
        else:
 | 
						|
            for idx in range(qty):
 | 
						|
                self._result_idx += 1
 | 
						|
                self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx])
 | 
						|
 | 
						|
            #return the connection to the connection pool if we have all objects
 | 
						|
            if self._result_cache and self._result_idx == (len(self._result_cache) - 1):
 | 
						|
                self._con = None
 | 
						|
                self._cur = None
 | 
						|
 | 
						|
    def __iter__(self):
 | 
						|
        self._execute_query()
 | 
						|
 | 
						|
        for idx in range(len(self._result_cache)):
 | 
						|
            instance = self._result_cache[idx]
 | 
						|
            if isinstance(instance, RowResult):
 | 
						|
                self._fill_result_cache_to_idx(idx)
 | 
						|
            yield self._result_cache[idx]
 | 
						|
 | 
						|
    def __getitem__(self, s):
 | 
						|
        self._execute_query()
 | 
						|
 | 
						|
        num_results = len(self._result_cache)
 | 
						|
 | 
						|
        if isinstance(s, slice):
 | 
						|
            #calculate the amount of results that need to be loaded
 | 
						|
            end = num_results if s.step is None else s.step
 | 
						|
            if end < 0:
 | 
						|
                end += num_results
 | 
						|
            else:
 | 
						|
                end -= 1
 | 
						|
            self._fill_result_cache_to_idx(end)
 | 
						|
            return self._result_cache[s.start:s.stop:s.step]
 | 
						|
        else:
 | 
						|
            #return the object at this index
 | 
						|
            s = long(s)
 | 
						|
 | 
						|
            #handle negative indexing
 | 
						|
            if s < 0: s += num_results
 | 
						|
 | 
						|
            if s >= num_results:
 | 
						|
                raise IndexError
 | 
						|
            else:
 | 
						|
                self._fill_result_cache_to_idx(s)
 | 
						|
                return self._result_cache[s]
 | 
						|
 | 
						|
    def _get_result_constructor(self, names):
 | 
						|
        """
 | 
						|
        Returns a function that will be used to instantiate query results
 | 
						|
        """
 | 
						|
        raise NotImplementedError
 | 
						|
 | 
						|
    def batch(self, batch_obj):
 | 
						|
        """
 | 
						|
        Adds a batch query to the mix
 | 
						|
        :param batch_obj:
 | 
						|
        :return:
 | 
						|
        """
 | 
						|
        if batch_obj is not None and not isinstance(batch_obj, BatchQuery):
 | 
						|
            raise CQLEngineException('batch_obj must be a BatchQuery instance or None')
 | 
						|
        clone = copy.deepcopy(self)
 | 
						|
        clone._batch = batch_obj
 | 
						|
        return clone
 | 
						|
 | 
						|
    def first(self):
 | 
						|
        try:
 | 
						|
            return iter(self).next()
 | 
						|
        except StopIteration:
 | 
						|
            return None
 | 
						|
 | 
						|
    def all(self):
 | 
						|
        return copy.deepcopy(self)
 | 
						|
 | 
						|
    def _parse_filter_arg(self, arg):
 | 
						|
        """
 | 
						|
        Parses a filter arg in the format:
 | 
						|
        <colname>__<op>
 | 
						|
        :returns: colname, op tuple
 | 
						|
        """
 | 
						|
        statement = arg.rsplit('__', 1)
 | 
						|
        if len(statement) == 1:
 | 
						|
            return arg, None
 | 
						|
        elif len(statement) == 2:
 | 
						|
            return statement[0], statement[1]
 | 
						|
        else:
 | 
						|
            raise QueryException("Can't parse '{}'".format(arg))
 | 
						|
 | 
						|
    def filter(self, *args, **kwargs):
 | 
						|
        """
 | 
						|
        Adds WHERE arguments to the queryset, returning a new queryset
 | 
						|
 | 
						|
        #TODO: show examples
 | 
						|
 | 
						|
        :rtype: AbstractQuerySet
 | 
						|
        """
 | 
						|
        #add arguments to the where clause filters
 | 
						|
        clone = copy.deepcopy(self)
 | 
						|
        for operator in args:
 | 
						|
            if not isinstance(operator, QueryOperator):
 | 
						|
                raise QueryException('{} is not a valid query operator'.format(operator))
 | 
						|
            clone._where.append(operator)
 | 
						|
 | 
						|
        for arg, val in kwargs.items():
 | 
						|
            col_name, col_op = self._parse_filter_arg(arg)
 | 
						|
            #resolve column and operator
 | 
						|
            try:
 | 
						|
                column = self.model._get_column(col_name)
 | 
						|
            except KeyError:
 | 
						|
                if col_name == 'pk__token':
 | 
						|
                    column = columns._PartitionKeysToken(self.model)
 | 
						|
                else:
 | 
						|
                    raise QueryException("Can't resolve column name: '{}'".format(col_name))
 | 
						|
 | 
						|
            #get query operator, or use equals if not supplied
 | 
						|
            operator_class = QueryOperator.get_operator(col_op or 'EQ')
 | 
						|
            operator = operator_class(column, val)
 | 
						|
 | 
						|
            clone._where.append(operator)
 | 
						|
 | 
						|
        return clone
 | 
						|
 | 
						|
    def get(self, *args, **kwargs):
 | 
						|
        """
 | 
						|
        Returns a single instance matching this query, optionally with additional filter kwargs.
 | 
						|
 | 
						|
        A DoesNotExistError will be raised if there are no rows matching the query
 | 
						|
        A MultipleObjectsFoundError will be raised if there is more than one row matching the queyr
 | 
						|
        """
 | 
						|
        if args or kwargs:
 | 
						|
            return self.filter(*args, **kwargs).get()
 | 
						|
 | 
						|
        self._execute_query()
 | 
						|
        if len(self._result_cache) == 0:
 | 
						|
            raise self.model.DoesNotExist
 | 
						|
        elif len(self._result_cache) > 1:
 | 
						|
            raise self.model.MultipleObjectsReturned(
 | 
						|
                    '{} objects found'.format(len(self._result_cache)))
 | 
						|
        else:
 | 
						|
            return self[0]
 | 
						|
 | 
						|
    def _get_ordering_condition(self, colname):
 | 
						|
        order_type = 'DESC' if colname.startswith('-') else 'ASC'
 | 
						|
        colname = colname.replace('-', '')
 | 
						|
 | 
						|
        return colname, order_type
 | 
						|
 | 
						|
    def order_by(self, *colnames):
 | 
						|
        """
 | 
						|
        orders the result set.
 | 
						|
        ordering can only use clustering columns.
 | 
						|
 | 
						|
        Default order is ascending, prepend a '-' to the column name for descending
 | 
						|
        """
 | 
						|
        if len(colnames) == 0:
 | 
						|
            clone = copy.deepcopy(self)
 | 
						|
            clone._order = []
 | 
						|
            return clone
 | 
						|
 | 
						|
        conditions = []
 | 
						|
        for colname in colnames:
 | 
						|
            conditions.append('"{}" {}'.format(*self._get_ordering_condition(colname)))
 | 
						|
 | 
						|
        clone = copy.deepcopy(self)
 | 
						|
        clone._order.extend(conditions)
 | 
						|
        return clone
 | 
						|
 | 
						|
    def count(self):
 | 
						|
        """ Returns the number of rows matched by this query """
 | 
						|
        if self._batch:
 | 
						|
            raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
 | 
						|
        #TODO: check for previous query execution and return row count if it exists
 | 
						|
        if self._result_cache is None:
 | 
						|
            qs = ['SELECT COUNT(*)']
 | 
						|
            qs += ['FROM {}'.format(self.column_family_name)]
 | 
						|
            if self._where:
 | 
						|
                qs += ['WHERE {}'.format(self._where_clause())]
 | 
						|
            if self._allow_filtering:
 | 
						|
                qs += ['ALLOW FILTERING']
 | 
						|
 | 
						|
            qs = ' '.join(qs)
 | 
						|
 | 
						|
            _, result = execute(qs, self._where_values())
 | 
						|
            return result[0][0]
 | 
						|
        else:
 | 
						|
            return len(self._result_cache)
 | 
						|
 | 
						|
    def limit(self, v):
 | 
						|
        """
 | 
						|
        Sets the limit on the number of results returned
 | 
						|
        CQL has a default limit of 10,000
 | 
						|
        """
 | 
						|
        if not (v is None or isinstance(v, (int, long))):
 | 
						|
            raise TypeError
 | 
						|
        if v == self._limit:
 | 
						|
            return self
 | 
						|
 | 
						|
        if v < 0:
 | 
						|
            raise QueryException("Negative limit is not allowed")
 | 
						|
 | 
						|
        clone = copy.deepcopy(self)
 | 
						|
        clone._limit = v
 | 
						|
        return clone
 | 
						|
 | 
						|
    def allow_filtering(self):
 | 
						|
        """
 | 
						|
        Enables the unwise practive of querying on a clustering
 | 
						|
        key without also defining a partition key
 | 
						|
        """
 | 
						|
        clone = copy.deepcopy(self)
 | 
						|
        clone._allow_filtering = True
 | 
						|
        return clone
 | 
						|
 | 
						|
    def _only_or_defer(self, action, fields):
 | 
						|
        clone = copy.deepcopy(self)
 | 
						|
        if clone._defer_fields or clone._only_fields:
 | 
						|
            raise QueryException("QuerySet alread has only or defer fields defined")
 | 
						|
 | 
						|
        #check for strange fields
 | 
						|
        missing_fields = [f for f in fields if f not in self.model._columns.keys()]
 | 
						|
        if missing_fields:
 | 
						|
            raise QueryException(
 | 
						|
                "Can't resolve fields {} in {}".format(
 | 
						|
                    ', '.join(missing_fields), self.model.__name__))
 | 
						|
 | 
						|
        if action == 'defer':
 | 
						|
            clone._defer_fields = fields
 | 
						|
        elif action == 'only':
 | 
						|
            clone._only_fields = fields
 | 
						|
        else:
 | 
						|
            raise ValueError
 | 
						|
 | 
						|
        return clone
 | 
						|
 | 
						|
    def only(self, fields):
 | 
						|
        """ Load only these fields for the returned query """
 | 
						|
        return self._only_or_defer('only', fields)
 | 
						|
 | 
						|
    def defer(self, fields):
 | 
						|
        """ Don't load these fields for the returned query """
 | 
						|
        return self._only_or_defer('defer', fields)
 | 
						|
 | 
						|
    def create(self, **kwargs):
 | 
						|
        return self.model(**kwargs).batch(self._batch).ttl(self._ttl).consistency(self._consistency).save()
 | 
						|
 | 
						|
    #----delete---
 | 
						|
    def delete(self, columns=[]):
 | 
						|
        """
 | 
						|
        Deletes the contents of a query
 | 
						|
        """
 | 
						|
        #validate where clause
 | 
						|
        partition_key = self.model._primary_keys.values()[0]
 | 
						|
        if not any([c.column.db_field_name == partition_key.db_field_name for c in self._where]):
 | 
						|
            raise QueryException("The partition key must be defined on delete queries")
 | 
						|
        qs = ['DELETE FROM {}'.format(self.column_family_name)]
 | 
						|
        qs += ['WHERE {}'.format(self._where_clause())]
 | 
						|
        qs = ' '.join(qs)
 | 
						|
 | 
						|
        if self._batch:
 | 
						|
            self._batch.add_query(qs, self._where_values())
 | 
						|
        else:
 | 
						|
            execute(qs, self._where_values())
 | 
						|
 | 
						|
    def __eq__(self, q):
 | 
						|
        return set(self._where) == set(q._where)
 | 
						|
 | 
						|
    def __ne__(self, q):
 | 
						|
        return not (self != q)
 | 
						|
 | 
						|
 | 
						|
class ResultObject(dict):
 | 
						|
    """
 | 
						|
    adds attribute access to a dictionary
 | 
						|
    """
 | 
						|
 | 
						|
    def __getattr__(self, item):
 | 
						|
        try:
 | 
						|
            return self[item]
 | 
						|
        except KeyError:
 | 
						|
            raise AttributeError
 | 
						|
 | 
						|
 | 
						|
class SimpleQuerySet(AbstractQuerySet):
 | 
						|
    """
 | 
						|
 | 
						|
    """
 | 
						|
 | 
						|
    def _get_select_statement(self):
 | 
						|
        """ Returns the fields to be returned by the select query """
 | 
						|
        return 'SELECT *'
 | 
						|
 | 
						|
    def _get_result_constructor(self, names):
 | 
						|
        """
 | 
						|
        Returns a function that will be used to instantiate query results
 | 
						|
        """
 | 
						|
        def _construct_instance(values):
 | 
						|
            return ResultObject(zip(names, values))
 | 
						|
        return _construct_instance
 | 
						|
 | 
						|
 | 
						|
class ModelQuerySet(AbstractQuerySet):
 | 
						|
    """
 | 
						|
 | 
						|
    """
 | 
						|
    def _validate_where_syntax(self):
 | 
						|
        """ Checks that a filterset will not create invalid cql """
 | 
						|
 | 
						|
        #check that there's either a = or IN relationship with a primary key or indexed field
 | 
						|
        equal_ops = [w for w in self._where if isinstance(w, EqualsOperator)]
 | 
						|
        token_ops = [w for w in self._where if isinstance(w.value, Token)]
 | 
						|
        if not any([w.column.primary_key or w.column.index for w in equal_ops]) and not token_ops:
 | 
						|
            raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field')
 | 
						|
 | 
						|
        if not self._allow_filtering:
 | 
						|
            #if the query is not on an indexed field
 | 
						|
            if not any([w.column.index for w in equal_ops]):
 | 
						|
                if not any([w.column.partition_key for w in equal_ops]) and not token_ops:
 | 
						|
                    raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset')
 | 
						|
            if any(not w.column.partition_key for w in token_ops):
 | 
						|
                raise QueryException('The token() function is only supported on the partition key')
 | 
						|
 | 
						|
    def _where_clause(self):
 | 
						|
        """ Returns a where clause based on the given filter args """
 | 
						|
        self._validate_where_syntax()
 | 
						|
        return super(ModelQuerySet, self)._where_clause()
 | 
						|
 | 
						|
    def _get_select_statement(self):
 | 
						|
        """ Returns the fields to be returned by the select query """
 | 
						|
        if self._defer_fields or self._only_fields:
 | 
						|
            fields = self.model._columns.keys()
 | 
						|
            if self._defer_fields:
 | 
						|
                fields = [f for f in fields if f not in self._defer_fields]
 | 
						|
            elif self._only_fields:
 | 
						|
                fields = self._only_fields
 | 
						|
            db_fields = [self.model._columns[f].db_field_name for f in fields]
 | 
						|
            return 'SELECT {}'.format(', '.join(['"{}"'.format(f) for f in db_fields]))
 | 
						|
        else:
 | 
						|
            return 'SELECT *'
 | 
						|
 | 
						|
    def _get_result_constructor(self, names):
 | 
						|
        """ Returns a function that will be used to instantiate query results """
 | 
						|
        if not self._values_list:
 | 
						|
            return lambda values: self.model._construct_instance(names, values)
 | 
						|
        else:
 | 
						|
            columns = [self.model._columns[n] for n in names]
 | 
						|
            if self._flat_values_list:
 | 
						|
                return lambda values: columns[0].to_python(values[0])
 | 
						|
            else:
 | 
						|
                return lambda values: map(lambda (c, v): c.to_python(v), zip(columns, values))
 | 
						|
 | 
						|
    def _get_ordering_condition(self, colname):
 | 
						|
        colname, order_type = super(ModelQuerySet, self)._get_ordering_condition(colname)
 | 
						|
 | 
						|
        column = self.model._columns.get(colname)
 | 
						|
        if column is None:
 | 
						|
            raise QueryException("Can't resolve the column name: '{}'".format(colname))
 | 
						|
 | 
						|
        #validate the column selection
 | 
						|
        if not column.primary_key:
 | 
						|
            raise QueryException(
 | 
						|
                "Can't order on '{}', can only order on (clustered) primary keys".format(colname))
 | 
						|
 | 
						|
        pks = [v for k, v in self.model._columns.items() if v.primary_key]
 | 
						|
        if column == pks[0]:
 | 
						|
            raise QueryException(
 | 
						|
                "Can't order by the first primary key (partition key), clustering (secondary) keys only")
 | 
						|
 | 
						|
        return column.db_field_name, order_type
 | 
						|
 | 
						|
    def _get_ttl_statement(self):
 | 
						|
        if not self._ttl:
 | 
						|
            return ""
 | 
						|
        return "USING TTL {}".format(self._ttl)
 | 
						|
 | 
						|
    def values_list(self, *fields, **kwargs):
 | 
						|
        """ Instructs the query set to return tuples, not model instance """
 | 
						|
        flat = kwargs.pop('flat', False)
 | 
						|
        if kwargs:
 | 
						|
            raise TypeError('Unexpected keyword arguments to values_list: %s'
 | 
						|
                            % (kwargs.keys(),))
 | 
						|
        if flat and len(fields) > 1:
 | 
						|
            raise TypeError("'flat' is not valid when values_list is called with more than one field.")
 | 
						|
        clone = self.only(fields)
 | 
						|
        clone._values_list = True
 | 
						|
        clone._flat_values_list = flat
 | 
						|
        return clone
 | 
						|
 | 
						|
    def consistency(self, consistency):
 | 
						|
        self._consistency = consistency
 | 
						|
        return self
 | 
						|
 | 
						|
    def ttl(self, ttl):
 | 
						|
        self._ttl = ttl
 | 
						|
        return self
 | 
						|
 | 
						|
    def update(self, **values):
 | 
						|
        """ Updates the rows in this queryset """
 | 
						|
        if not values:
 | 
						|
            return
 | 
						|
 | 
						|
        set_statements = []
 | 
						|
        ctx = {}
 | 
						|
        nulled_columns = set()
 | 
						|
        for name, val in values.items():
 | 
						|
            col = self.model._columns.get(name)
 | 
						|
            # check for nonexistant columns
 | 
						|
            if col is None:
 | 
						|
                raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, name))
 | 
						|
            # check for primary key update attempts
 | 
						|
            if col.is_primary_key:
 | 
						|
                raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(name, self.__module__, self.model.__name__))
 | 
						|
 | 
						|
            val = col.validate(val)
 | 
						|
            if val is None:
 | 
						|
                nulled_columns.add(name)
 | 
						|
                continue
 | 
						|
            # add the update statements
 | 
						|
            if isinstance(col, (BaseContainerColumn, Counter)):
 | 
						|
                val_mgr = self.instance._values[name]
 | 
						|
                set_statements += col.get_update_statement(val, val_mgr.previous_value, ctx)
 | 
						|
 | 
						|
            else:
 | 
						|
                field_id = uuid4().hex
 | 
						|
                set_statements += ['"{}" = :{}'.format(col.db_field_name, field_id)]
 | 
						|
                ctx[field_id] = val
 | 
						|
 | 
						|
        if set_statements:
 | 
						|
            qs = "UPDATE {} SET {} WHERE {}".format(
 | 
						|
                self.column_family_name,
 | 
						|
                ', '.join(set_statements),
 | 
						|
                self._where_clause()
 | 
						|
            )
 | 
						|
            ctx.update(self._where_values())
 | 
						|
            execute(qs, ctx, self._consistency)
 | 
						|
 | 
						|
        if nulled_columns:
 | 
						|
            qs = "DELETE {} FROM {} WHERE {}".format(
 | 
						|
                ', '.join(nulled_columns),
 | 
						|
                self.column_family_name,
 | 
						|
                self._where_clause()
 | 
						|
            )
 | 
						|
            execute(qs, self._where_values(), self._consistency)
 | 
						|
 | 
						|
 | 
						|
class DMLQuery(object):
 | 
						|
    """
 | 
						|
    A query object used for queries performing inserts, updates, or deletes
 | 
						|
 | 
						|
    this is usually instantiated by the model instance to be modified
 | 
						|
 | 
						|
    unlike the read query object, this is mutable
 | 
						|
    """
 | 
						|
    _ttl = None
 | 
						|
    _consistency = None
 | 
						|
 | 
						|
    def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None):
 | 
						|
        self.model = model
 | 
						|
        self.column_family_name = self.model.column_family_name()
 | 
						|
        self.instance = instance
 | 
						|
        self._batch = batch
 | 
						|
        self._ttl = ttl
 | 
						|
        self._consistency = consistency
 | 
						|
 | 
						|
    def batch(self, batch_obj):
 | 
						|
        if batch_obj is not None and not isinstance(batch_obj, BatchQuery):
 | 
						|
            raise CQLEngineException('batch_obj must be a BatchQuery instance or None')
 | 
						|
        self._batch = batch_obj
 | 
						|
        return self
 | 
						|
 | 
						|
    def _delete_null_columns(self):
 | 
						|
        """
 | 
						|
        executes a delete query to remove columns that have changed to null
 | 
						|
        """
 | 
						|
        values, field_names, field_ids, field_values, query_values = self._get_query_values()
 | 
						|
 | 
						|
        # delete nulled columns and removed map keys
 | 
						|
        qs = ['DELETE']
 | 
						|
        query_values = {}
 | 
						|
 | 
						|
        del_statements = []
 | 
						|
        for k,v in self.instance._values.items():
 | 
						|
            col = v.column
 | 
						|
            if v.deleted:
 | 
						|
                del_statements += ['"{}"'.format(col.db_field_name)]
 | 
						|
            elif isinstance(col, Map):
 | 
						|
                del_statements += col.get_delete_statement(v.value, v.previous_value, query_values)
 | 
						|
 | 
						|
        if del_statements:
 | 
						|
            qs += [', '.join(del_statements)]
 | 
						|
 | 
						|
            qs += ['FROM {}'.format(self.column_family_name)]
 | 
						|
 | 
						|
            qs += ['WHERE']
 | 
						|
            where_statements = []
 | 
						|
            for name, col in self.model._primary_keys.items():
 | 
						|
                field_id = uuid4().hex
 | 
						|
                query_values[field_id] = field_values[name]
 | 
						|
                where_statements += ['"{}" = :{}'.format(col.db_field_name, field_id)]
 | 
						|
            qs += [' AND '.join(where_statements)]
 | 
						|
 | 
						|
            qs = ' '.join(qs)
 | 
						|
 | 
						|
            if self._batch:
 | 
						|
                self._batch.add_query(qs, query_values)
 | 
						|
            else:
 | 
						|
                execute(qs, query_values)
 | 
						|
 | 
						|
    def update(self):
 | 
						|
        """
 | 
						|
        updates a row.
 | 
						|
        This is a blind update call.
 | 
						|
        All validation and cleaning needs to happen
 | 
						|
        prior to calling this.
 | 
						|
        """
 | 
						|
        if self.instance is None:
 | 
						|
            raise CQLEngineException("DML Query intance attribute is None")
 | 
						|
        assert type(self.instance) == self.model
 | 
						|
 | 
						|
        values, field_names, field_ids, field_values, query_values = self._get_query_values()
 | 
						|
 | 
						|
        qs = []
 | 
						|
        qs += ["UPDATE {}".format(self.column_family_name)]
 | 
						|
        qs += ["SET"]
 | 
						|
 | 
						|
        set_statements = []
 | 
						|
        #get defined fields and their column names
 | 
						|
        for name, col in self.model._columns.items():
 | 
						|
            if not col.is_primary_key:
 | 
						|
                val = values.get(name)
 | 
						|
 | 
						|
                # don't update something that is null
 | 
						|
                if val is None:
 | 
						|
                    continue
 | 
						|
 | 
						|
                # don't update something if it hasn't changed
 | 
						|
                if not self.instance._values[name].changed and not isinstance(col, Counter):
 | 
						|
                    continue
 | 
						|
 | 
						|
                # add the update statements
 | 
						|
                if isinstance(col, (BaseContainerColumn, Counter)):
 | 
						|
                    #remove value from query values, the column will handle it
 | 
						|
                    query_values.pop(field_ids.get(name), None)
 | 
						|
 | 
						|
                    val_mgr = self.instance._values[name]
 | 
						|
                    set_statements += col.get_update_statement(val, val_mgr.previous_value, query_values)
 | 
						|
 | 
						|
                else:
 | 
						|
                    set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
 | 
						|
        qs += [', '.join(set_statements)]
 | 
						|
 | 
						|
        qs += ['WHERE']
 | 
						|
 | 
						|
        where_statements = []
 | 
						|
        for name, col in self.model._primary_keys.items():
 | 
						|
            where_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
 | 
						|
 | 
						|
        qs += [' AND '.join(where_statements)]
 | 
						|
 | 
						|
        if self._ttl:
 | 
						|
            qs += ["USING TTL {}".format(self._ttl)]
 | 
						|
 | 
						|
        # clear the qs if there are no set statements and this is not a counter model
 | 
						|
        if not set_statements and not self.instance._has_counter:
 | 
						|
            qs = []
 | 
						|
 | 
						|
        qs = ' '.join(qs)
 | 
						|
        # skip query execution if it's empty
 | 
						|
        # caused by pointless update queries
 | 
						|
        if qs:
 | 
						|
            if self._batch:
 | 
						|
                self._batch.add_query(qs, query_values)
 | 
						|
            else:
 | 
						|
                execute(qs, query_values, consistency_level=self._consistency)
 | 
						|
 | 
						|
        self._delete_null_columns()
 | 
						|
 | 
						|
    def _get_query_values(self):
 | 
						|
        """
 | 
						|
        returns all the data needed to do queries
 | 
						|
        """
 | 
						|
        #organize data
 | 
						|
        value_pairs = []
 | 
						|
        values = self.instance._as_dict()
 | 
						|
 | 
						|
        #get defined fields and their column names
 | 
						|
        for name, col in self.model._columns.items():
 | 
						|
            val = values.get(name)
 | 
						|
            if col._val_is_null(val): continue
 | 
						|
            value_pairs += [(col.db_field_name, val)]
 | 
						|
 | 
						|
        #construct query string
 | 
						|
        field_names = zip(*value_pairs)[0]
 | 
						|
        field_ids = {n:uuid4().hex for n in field_names}
 | 
						|
        field_values = dict(value_pairs)
 | 
						|
        query_values = {field_ids[n]:field_values[n] for n in field_names}
 | 
						|
        return values, field_names, field_ids, field_values, query_values
 | 
						|
 | 
						|
    def save(self):
 | 
						|
        """
 | 
						|
        Creates / updates a row.
 | 
						|
        This is a blind insert call.
 | 
						|
        All validation and cleaning needs to happen
 | 
						|
        prior to calling this.
 | 
						|
        """
 | 
						|
        if self.instance is None:
 | 
						|
            raise CQLEngineException("DML Query intance attribute is None")
 | 
						|
        assert type(self.instance) == self.model
 | 
						|
 | 
						|
        values, field_names, field_ids, field_values, query_values = self._get_query_values()
 | 
						|
 | 
						|
        qs = []
 | 
						|
        if self.instance._has_counter or self.instance._can_update():
 | 
						|
            return self.update()
 | 
						|
        else:
 | 
						|
            qs += ["INSERT INTO {}".format(self.column_family_name)]
 | 
						|
            qs += ["({})".format(', '.join(['"{}"'.format(f) for f in field_names]))]
 | 
						|
            qs += ['VALUES']
 | 
						|
            qs += ["({})".format(', '.join([':'+field_ids[f] for f in field_names]))]
 | 
						|
 | 
						|
        if self._ttl:
 | 
						|
            qs += ["USING TTL {}".format(self._ttl)]
 | 
						|
 | 
						|
        qs += []
 | 
						|
        qs = ' '.join(qs)
 | 
						|
 | 
						|
 | 
						|
        # skip query execution if it's empty
 | 
						|
        # caused by pointless update queries
 | 
						|
        if qs:
 | 
						|
            if self._batch:
 | 
						|
                self._batch.add_query(qs, query_values)
 | 
						|
            else:
 | 
						|
                execute(qs, query_values, self._consistency)
 | 
						|
 | 
						|
        # delete any nulled columns
 | 
						|
        self._delete_null_columns()
 | 
						|
 | 
						|
    def delete(self):
 | 
						|
        """ Deletes one instance """
 | 
						|
        if self.instance is None:
 | 
						|
            raise CQLEngineException("DML Query intance attribute is None")
 | 
						|
        field_values = {}
 | 
						|
        qs = ['DELETE FROM {}'.format(self.column_family_name)]
 | 
						|
        qs += ['WHERE']
 | 
						|
        where_statements = []
 | 
						|
        for name, col in self.model._primary_keys.items():
 | 
						|
            field_id = uuid4().hex
 | 
						|
            field_values[field_id] = col.to_database(getattr(self.instance, name))
 | 
						|
            where_statements += ['"{}" = :{}'.format(col.db_field_name, field_id)]
 | 
						|
 | 
						|
        qs += [' AND '.join(where_statements)]
 | 
						|
        qs = ' '.join(qs)
 | 
						|
 | 
						|
        if self._batch:
 | 
						|
            self._batch.add_query(qs, field_values)
 | 
						|
        else:
 | 
						|
            execute(qs, field_values, self._consistency)
 | 
						|
 | 
						|
 |