884 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			884 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2013-2016 DataStax, Inc.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| # http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| from datetime import datetime, timedelta
 | |
| from itertools import ifilter
 | |
| import time
 | |
| import six
 | |
| 
 | |
| from cassandra.query import FETCH_SIZE_UNSET
 | |
| from cassandra.cqlengine import columns
 | |
| from cassandra.cqlengine import UnicodeMixin
 | |
| from cassandra.cqlengine.functions import QueryValue
 | |
| from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator
 | |
| 
 | |
| 
 | |
| class StatementException(Exception):
 | |
|     pass
 | |
| 
 | |
| 
 | |
| class ValueQuoter(UnicodeMixin):
 | |
| 
 | |
|     def __init__(self, value):
 | |
|         self.value = value
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         from cassandra.encoder import cql_quote
 | |
|         if isinstance(self.value, bool):
 | |
|             return 'true' if self.value else 'false'
 | |
|         elif isinstance(self.value, (list, tuple)):
 | |
|             return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']'
 | |
|         elif isinstance(self.value, dict):
 | |
|             return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}'
 | |
|         elif isinstance(self.value, set):
 | |
|             return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}'
 | |
|         return cql_quote(self.value)
 | |
| 
 | |
|     def __eq__(self, other):
 | |
|         if isinstance(other, self.__class__):
 | |
|             return self.value == other.value
 | |
|         return False
 | |
| 
 | |
| 
 | |
| class InQuoter(ValueQuoter):
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         from cassandra.encoder import cql_quote
 | |
|         return '(' + ', '.join([cql_quote(v) for v in self.value]) + ')'
 | |
| 
 | |
| 
 | |
| class BaseClause(UnicodeMixin):
 | |
| 
 | |
|     def __init__(self, field, value):
 | |
|         self.field = field
 | |
|         self.value = value
 | |
|         self.context_id = None
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         raise NotImplementedError
 | |
| 
 | |
|     def __hash__(self):
 | |
|         return hash(self.field) ^ hash(self.value)
 | |
| 
 | |
|     def __eq__(self, other):
 | |
|         if isinstance(other, self.__class__):
 | |
|             return self.field == other.field and self.value == other.value
 | |
|         return False
 | |
| 
 | |
|     def __ne__(self, other):
 | |
|         return not self.__eq__(other)
 | |
| 
 | |
|     def get_context_size(self):
 | |
|         """ returns the number of entries this clause will add to the query context """
 | |
|         return 1
 | |
| 
 | |
|     def set_context_id(self, i):
 | |
|         """ sets the value placeholder that will be used in the query """
 | |
|         self.context_id = i
 | |
| 
 | |
|     def update_context(self, ctx):
 | |
|         """ updates the query context with this clauses values """
 | |
|         assert isinstance(ctx, dict)
 | |
|         ctx[str(self.context_id)] = self.value
 | |
| 
 | |
| 
 | |
| class WhereClause(BaseClause):
 | |
|     """ a single where statement used in queries """
 | |
| 
 | |
|     def __init__(self, field, operator, value, quote_field=True):
 | |
|         """
 | |
| 
 | |
|         :param field:
 | |
|         :param operator:
 | |
|         :param value:
 | |
|         :param quote_field: hack to get the token function rendering properly
 | |
|         :return:
 | |
|         """
 | |
|         if not isinstance(operator, BaseWhereOperator):
 | |
|             raise StatementException(
 | |
|                 "operator must be of type {0}, got {1}".format(BaseWhereOperator, type(operator))
 | |
|             )
 | |
|         super(WhereClause, self).__init__(field, value)
 | |
|         self.operator = operator
 | |
|         self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value)
 | |
|         self.quote_field = quote_field
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         field = ('"{0}"' if self.quote_field else '{0}').format(self.field)
 | |
|         return u'{0} {1} {2}'.format(field, self.operator, six.text_type(self.query_value))
 | |
| 
 | |
|     def __hash__(self):
 | |
|         return super(WhereClause, self).__hash__() ^ hash(self.operator)
 | |
| 
 | |
|     def __eq__(self, other):
 | |
|         if super(WhereClause, self).__eq__(other):
 | |
|             return self.operator.__class__ == other.operator.__class__
 | |
|         return False
 | |
| 
 | |
|     def get_context_size(self):
 | |
|         return self.query_value.get_context_size()
 | |
| 
 | |
|     def set_context_id(self, i):
 | |
|         super(WhereClause, self).set_context_id(i)
 | |
|         self.query_value.set_context_id(i)
 | |
| 
 | |
|     def update_context(self, ctx):
 | |
|         if isinstance(self.operator, InOperator):
 | |
|             ctx[str(self.context_id)] = InQuoter(self.value)
 | |
|         else:
 | |
|             self.query_value.update_context(ctx)
 | |
| 
 | |
| 
 | |
| class AssignmentClause(BaseClause):
 | |
|     """ a single variable st statement """
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         return u'"{0}" = %({1})s'.format(self.field, self.context_id)
 | |
| 
 | |
|     def insert_tuple(self):
 | |
|         return self.field, self.context_id
 | |
| 
 | |
| 
 | |
| class ConditionalClause(BaseClause):
 | |
|     """ A single variable iff statement """
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         return u'"{0}" = %({1})s'.format(self.field, self.context_id)
 | |
| 
 | |
|     def insert_tuple(self):
 | |
|         return self.field, self.context_id
 | |
| 
 | |
| 
 | |
| class ContainerUpdateTypeMapMeta(type):
 | |
| 
 | |
|     def __init__(cls, name, bases, dct):
 | |
|         if not hasattr(cls, 'type_map'):
 | |
|             cls.type_map = {}
 | |
|         else:
 | |
|             cls.type_map[cls.col_type] = cls
 | |
|         super(ContainerUpdateTypeMapMeta, cls).__init__(name, bases, dct)
 | |
| 
 | |
| 
 | |
| @six.add_metaclass(ContainerUpdateTypeMapMeta)
 | |
| class ContainerUpdateClause(AssignmentClause):
 | |
| 
 | |
|     def __init__(self, field, value, operation=None, previous=None):
 | |
|         super(ContainerUpdateClause, self).__init__(field, value)
 | |
|         self.previous = previous
 | |
|         self._assignments = None
 | |
|         self._operation = operation
 | |
|         self._analyzed = False
 | |
| 
 | |
|     def _analyze(self):
 | |
|         raise NotImplementedError
 | |
| 
 | |
|     def get_context_size(self):
 | |
|         raise NotImplementedError
 | |
| 
 | |
|     def update_context(self, ctx):
 | |
|         raise NotImplementedError
 | |
| 
 | |
| 
 | |
| class SetUpdateClause(ContainerUpdateClause):
 | |
|     """ updates a set collection """
 | |
| 
 | |
|     col_type = columns.Set
 | |
| 
 | |
|     _additions = None
 | |
|     _removals = None
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         qs = []
 | |
|         ctx_id = self.context_id
 | |
|         if (self.previous is None and
 | |
|                 self._assignments is None and
 | |
|                 self._additions is None and
 | |
|                 self._removals is None):
 | |
|             qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)]
 | |
|         if self._assignments is not None:
 | |
|             qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)]
 | |
|             ctx_id += 1
 | |
|         if self._additions is not None:
 | |
|             qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)]
 | |
|             ctx_id += 1
 | |
|         if self._removals is not None:
 | |
|             qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)]
 | |
| 
 | |
|         return ', '.join(qs)
 | |
| 
 | |
|     def _analyze(self):
 | |
|         """ works out the updates to be performed """
 | |
|         if self.value is None or self.value == self.previous:
 | |
|             pass
 | |
|         elif self._operation == "add":
 | |
|             self._additions = self.value
 | |
|         elif self._operation == "remove":
 | |
|             self._removals = self.value
 | |
|         elif self.previous is None:
 | |
|             self._assignments = self.value
 | |
|         else:
 | |
|             # partial update time
 | |
|             self._additions = (self.value - self.previous) or None
 | |
|             self._removals = (self.previous - self.value) or None
 | |
|         self._analyzed = True
 | |
| 
 | |
|     def get_context_size(self):
 | |
|         if not self._analyzed:
 | |
|             self._analyze()
 | |
|         if (self.previous is None and
 | |
|                 not self._assignments and
 | |
|                 self._additions is None and
 | |
|                 self._removals is None):
 | |
|             return 1
 | |
|         return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals))
 | |
| 
 | |
|     def update_context(self, ctx):
 | |
|         if not self._analyzed:
 | |
|             self._analyze()
 | |
|         ctx_id = self.context_id
 | |
|         if (self.previous is None and
 | |
|                 self._assignments is None and
 | |
|                 self._additions is None and
 | |
|                 self._removals is None):
 | |
|             ctx[str(ctx_id)] = set()
 | |
|         if self._assignments is not None:
 | |
|             ctx[str(ctx_id)] = self._assignments
 | |
|             ctx_id += 1
 | |
|         if self._additions is not None:
 | |
|             ctx[str(ctx_id)] = self._additions
 | |
|             ctx_id += 1
 | |
|         if self._removals is not None:
 | |
|             ctx[str(ctx_id)] = self._removals
 | |
| 
 | |
| 
 | |
| class ListUpdateClause(ContainerUpdateClause):
 | |
|     """ updates a list collection """
 | |
| 
 | |
|     col_type = columns.List
 | |
| 
 | |
|     _append = None
 | |
|     _prepend = None
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         if not self._analyzed:
 | |
|             self._analyze()
 | |
|         qs = []
 | |
|         ctx_id = self.context_id
 | |
|         if self._assignments is not None:
 | |
|             qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)]
 | |
|             ctx_id += 1
 | |
| 
 | |
|         if self._prepend is not None:
 | |
|             qs += ['"{0}" = %({1})s + "{0}"'.format(self.field, ctx_id)]
 | |
|             ctx_id += 1
 | |
| 
 | |
|         if self._append is not None:
 | |
|             qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)]
 | |
| 
 | |
|         return ', '.join(qs)
 | |
| 
 | |
|     def get_context_size(self):
 | |
|         if not self._analyzed:
 | |
|             self._analyze()
 | |
|         return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend))
 | |
| 
 | |
|     def update_context(self, ctx):
 | |
|         if not self._analyzed:
 | |
|             self._analyze()
 | |
|         ctx_id = self.context_id
 | |
|         if self._assignments is not None:
 | |
|             ctx[str(ctx_id)] = self._assignments
 | |
|             ctx_id += 1
 | |
|         if self._prepend is not None:
 | |
|             ctx[str(ctx_id)] = self._prepend
 | |
|             ctx_id += 1
 | |
|         if self._append is not None:
 | |
|             ctx[str(ctx_id)] = self._append
 | |
| 
 | |
|     def _analyze(self):
 | |
|         """ works out the updates to be performed """
 | |
|         if self.value is None or self.value == self.previous:
 | |
|             pass
 | |
| 
 | |
|         elif self._operation == "append":
 | |
|             self._append = self.value
 | |
| 
 | |
|         elif self._operation == "prepend":
 | |
|             self._prepend = self.value
 | |
| 
 | |
|         elif self.previous is None:
 | |
|             self._assignments = self.value
 | |
| 
 | |
|         elif len(self.value) < len(self.previous):
 | |
|             # if elements have been removed,
 | |
|             # rewrite the whole list
 | |
|             self._assignments = self.value
 | |
| 
 | |
|         elif len(self.previous) == 0:
 | |
|             # if we're updating from an empty
 | |
|             # list, do a complete insert
 | |
|             self._assignments = self.value
 | |
|         else:
 | |
| 
 | |
|             # the max start idx we want to compare
 | |
|             search_space = len(self.value) - max(0, len(self.previous) - 1)
 | |
| 
 | |
|             # the size of the sub lists we want to look at
 | |
|             search_size = len(self.previous)
 | |
| 
 | |
|             for i in range(search_space):
 | |
|                 # slice boundary
 | |
|                 j = i + search_size
 | |
|                 sub = self.value[i:j]
 | |
|                 idx_cmp = lambda idx: self.previous[idx] == sub[idx]
 | |
|                 if idx_cmp(0) and idx_cmp(-1) and self.previous == sub:
 | |
|                     self._prepend = self.value[:i] or None
 | |
|                     self._append = self.value[j:] or None
 | |
|                     break
 | |
| 
 | |
|             # if both append and prepend are still None after looking
 | |
|             # at both lists, an insert statement will be created
 | |
|             if self._prepend is self._append is None:
 | |
|                 self._assignments = self.value
 | |
| 
 | |
|         self._analyzed = True
 | |
| 
 | |
| 
 | |
| class MapUpdateClause(ContainerUpdateClause):
 | |
|     """ updates a map collection """
 | |
| 
 | |
|     col_type = columns.Map
 | |
| 
 | |
|     _updates = None
 | |
| 
 | |
|     def _analyze(self):
 | |
|         if self._operation == "update":
 | |
|             self._updates = self.value.keys()
 | |
|         else:
 | |
|             if self.previous is None:
 | |
|                 self._updates = sorted([k for k, v in self.value.items()])
 | |
|             else:
 | |
|                 self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None
 | |
|         self._analyzed = True
 | |
| 
 | |
|     def get_context_size(self):
 | |
|         if self.is_assignment:
 | |
|             return 1
 | |
|         return len(self._updates or []) * 2
 | |
| 
 | |
|     def update_context(self, ctx):
 | |
|         ctx_id = self.context_id
 | |
|         if self.is_assignment:
 | |
|             ctx[str(ctx_id)] = {}
 | |
|         else:
 | |
|             for key in self._updates or []:
 | |
|                 val = self.value.get(key)
 | |
|                 ctx[str(ctx_id)] = key
 | |
|                 ctx[str(ctx_id + 1)] = val
 | |
|                 ctx_id += 2
 | |
| 
 | |
|     @property
 | |
|     def is_assignment(self):
 | |
|         if not self._analyzed:
 | |
|             self._analyze()
 | |
|         return self.previous is None and not self._updates
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         qs = []
 | |
| 
 | |
|         ctx_id = self.context_id
 | |
|         if self.is_assignment:
 | |
|             qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)]
 | |
|         else:
 | |
|             for _ in self._updates or []:
 | |
|                 qs += ['"{0}"[%({1})s] = %({2})s'.format(self.field, ctx_id, ctx_id + 1)]
 | |
|                 ctx_id += 2
 | |
| 
 | |
|         return ', '.join(qs)
 | |
| 
 | |
| 
 | |
| class CounterUpdateClause(AssignmentClause):
 | |
| 
 | |
|     col_type = columns.Counter
 | |
| 
 | |
|     def __init__(self, field, value, previous=None):
 | |
|         super(CounterUpdateClause, self).__init__(field, value)
 | |
|         self.previous = previous or 0
 | |
| 
 | |
|     def get_context_size(self):
 | |
|         return 1
 | |
| 
 | |
|     def update_context(self, ctx):
 | |
|         ctx[str(self.context_id)] = abs(self.value - self.previous)
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         delta = self.value - self.previous
 | |
|         sign = '-' if delta < 0 else '+'
 | |
|         return '"{0}" = "{0}" {1} %({2})s'.format(self.field, sign, self.context_id)
 | |
| 
 | |
| 
 | |
| class BaseDeleteClause(BaseClause):
 | |
|     pass
 | |
| 
 | |
| 
 | |
| class FieldDeleteClause(BaseDeleteClause):
 | |
|     """ deletes a field from a row """
 | |
| 
 | |
|     def __init__(self, field):
 | |
|         super(FieldDeleteClause, self).__init__(field, None)
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         return '"{0}"'.format(self.field)
 | |
| 
 | |
|     def update_context(self, ctx):
 | |
|         pass
 | |
| 
 | |
|     def get_context_size(self):
 | |
|         return 0
 | |
| 
 | |
| 
 | |
| class MapDeleteClause(BaseDeleteClause):
 | |
|     """ removes keys from a map """
 | |
| 
 | |
|     def __init__(self, field, value, previous=None):
 | |
|         super(MapDeleteClause, self).__init__(field, value)
 | |
|         self.value = self.value or {}
 | |
|         self.previous = previous or {}
 | |
|         self._analyzed = False
 | |
|         self._removals = None
 | |
| 
 | |
|     def _analyze(self):
 | |
|         self._removals = sorted([k for k in self.previous if k not in self.value])
 | |
|         self._analyzed = True
 | |
| 
 | |
|     def update_context(self, ctx):
 | |
|         if not self._analyzed:
 | |
|             self._analyze()
 | |
|         for idx, key in enumerate(self._removals):
 | |
|             ctx[str(self.context_id + idx)] = key
 | |
| 
 | |
|     def get_context_size(self):
 | |
|         if not self._analyzed:
 | |
|             self._analyze()
 | |
|         return len(self._removals)
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         if not self._analyzed:
 | |
|             self._analyze()
 | |
|         return ', '.join(['"{0}"[%({1})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))])
 | |
| 
 | |
| 
 | |
| class BaseCQLStatement(UnicodeMixin):
 | |
|     """ The base cql statement class """
 | |
| 
 | |
|     def __init__(self, table, timestamp=None, where=None, fetch_size=None, conditionals=None):
 | |
|         super(BaseCQLStatement, self).__init__()
 | |
|         self.table = table
 | |
|         self.context_id = 0
 | |
|         self.context_counter = self.context_id
 | |
|         self.timestamp = timestamp
 | |
|         self.fetch_size = fetch_size if fetch_size else FETCH_SIZE_UNSET
 | |
| 
 | |
|         self.where_clauses = []
 | |
|         for clause in where or []:
 | |
|             self._add_where_clause(clause)
 | |
| 
 | |
|         self.conditionals = []
 | |
|         for conditional in conditionals or []:
 | |
|             self.add_conditional_clause(conditional)
 | |
| 
 | |
|     def _update_part_key_values(self, field_index_map, clauses, parts):
 | |
|         for clause in ifilter(lambda c: c.field in field_index_map, clauses):
 | |
|             parts[field_index_map[clause.field]] = clause.value
 | |
| 
 | |
|     def partition_key_values(self, field_index_map):
 | |
|         parts = [None] * len(field_index_map)
 | |
|         self._update_part_key_values(field_index_map, (w for w in self.where_clauses if w.operator.__class__ == EqualsOperator), parts)
 | |
|         return parts
 | |
| 
 | |
|     def add_where(self, column, operator, value, quote_field=True):
 | |
|         value = column.to_database(value)
 | |
|         clause = WhereClause(column.db_field_name, operator, value, quote_field)
 | |
|         self._add_where_clause(clause)
 | |
| 
 | |
|     def _add_where_clause(self, clause):
 | |
|         clause.set_context_id(self.context_counter)
 | |
|         self.context_counter += clause.get_context_size()
 | |
|         self.where_clauses.append(clause)
 | |
| 
 | |
|     def get_context(self):
 | |
|         """
 | |
|         returns the context dict for this statement
 | |
|         :rtype: dict
 | |
|         """
 | |
|         ctx = {}
 | |
|         for clause in self.where_clauses or []:
 | |
|             clause.update_context(ctx)
 | |
|         return ctx
 | |
| 
 | |
|     def add_conditional_clause(self, clause):
 | |
|         """
 | |
|         Adds a iff clause to this statement
 | |
| 
 | |
|         :param clause: The clause that will be added to the iff statement
 | |
|         :type clause: ConditionalClause
 | |
|         """
 | |
|         if not isinstance(clause, ConditionalClause):
 | |
|             raise StatementException('only instances of AssignmentClause can be added to statements')
 | |
|         clause.set_context_id(self.context_counter)
 | |
|         self.context_counter += clause.get_context_size()
 | |
|         self.conditionals.append(clause)
 | |
| 
 | |
|     def _get_conditionals(self):
 | |
|         return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.conditionals]))
 | |
| 
 | |
|     def get_context_size(self):
 | |
|         return len(self.get_context())
 | |
| 
 | |
|     def update_context_id(self, i):
 | |
|         self.context_id = i
 | |
|         self.context_counter = self.context_id
 | |
|         for clause in self.where_clauses:
 | |
|             clause.set_context_id(self.context_counter)
 | |
|             self.context_counter += clause.get_context_size()
 | |
| 
 | |
|     @property
 | |
|     def timestamp_normalized(self):
 | |
|         """
 | |
|         we're expecting self.timestamp to be either a long, int, a datetime, or a timedelta
 | |
|         :return:
 | |
|         """
 | |
|         if not self.timestamp:
 | |
|             return None
 | |
| 
 | |
|         if isinstance(self.timestamp, six.integer_types):
 | |
|             return self.timestamp
 | |
| 
 | |
|         if isinstance(self.timestamp, timedelta):
 | |
|             tmp = datetime.now() + self.timestamp
 | |
|         else:
 | |
|             tmp = self.timestamp
 | |
| 
 | |
|         return int(time.mktime(tmp.timetuple()) * 1e+6 + tmp.microsecond)
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         raise NotImplementedError
 | |
| 
 | |
|     def __repr__(self):
 | |
|         return self.__unicode__()
 | |
| 
 | |
|     @property
 | |
|     def _where(self):
 | |
|         return 'WHERE {0}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses]))
 | |
| 
 | |
| 
 | |
| class SelectStatement(BaseCQLStatement):
 | |
|     """ a cql select statement """
 | |
| 
 | |
|     def __init__(self,
 | |
|                  table,
 | |
|                  fields=None,
 | |
|                  count=False,
 | |
|                  where=None,
 | |
|                  order_by=None,
 | |
|                  limit=None,
 | |
|                  allow_filtering=False,
 | |
|                  distinct_fields=None,
 | |
|                  fetch_size=None):
 | |
| 
 | |
|         """
 | |
|         :param where
 | |
|         :type where list of cqlengine.statements.WhereClause
 | |
|         """
 | |
|         super(SelectStatement, self).__init__(
 | |
|             table,
 | |
|             where=where,
 | |
|             fetch_size=fetch_size
 | |
|         )
 | |
| 
 | |
|         self.fields = [fields] if isinstance(fields, six.string_types) else (fields or [])
 | |
|         self.distinct_fields = distinct_fields
 | |
|         self.count = count
 | |
|         self.order_by = [order_by] if isinstance(order_by, six.string_types) else order_by
 | |
|         self.limit = limit
 | |
|         self.allow_filtering = allow_filtering
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         qs = ['SELECT']
 | |
|         if self.distinct_fields:
 | |
|             if self.count:
 | |
|                 qs += ['DISTINCT COUNT({0})'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
 | |
|             else:
 | |
|                 qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
 | |
|         elif self.count:
 | |
|             qs += ['COUNT(*)']
 | |
|         else:
 | |
|             qs += [', '.join(['"{0}"'.format(f) for f in self.fields]) if self.fields else '*']
 | |
|         qs += ['FROM', self.table]
 | |
| 
 | |
|         if self.where_clauses:
 | |
|             qs += [self._where]
 | |
| 
 | |
|         if self.order_by and not self.count:
 | |
|             qs += ['ORDER BY {0}'.format(', '.join(six.text_type(o) for o in self.order_by))]
 | |
| 
 | |
|         if self.limit:
 | |
|             qs += ['LIMIT {0}'.format(self.limit)]
 | |
| 
 | |
|         if self.allow_filtering:
 | |
|             qs += ['ALLOW FILTERING']
 | |
| 
 | |
|         return ' '.join(qs)
 | |
| 
 | |
| 
 | |
| class AssignmentStatement(BaseCQLStatement):
 | |
|     """ value assignment statements """
 | |
| 
 | |
|     def __init__(self,
 | |
|                  table,
 | |
|                  assignments=None,
 | |
|                  where=None,
 | |
|                  ttl=None,
 | |
|                  timestamp=None,
 | |
|                  conditionals=None):
 | |
|         super(AssignmentStatement, self).__init__(
 | |
|             table,
 | |
|             where=where,
 | |
|             conditionals=conditionals
 | |
|         )
 | |
|         self.ttl = ttl
 | |
|         self.timestamp = timestamp
 | |
| 
 | |
|         # add assignments
 | |
|         self.assignments = []
 | |
|         for assignment in assignments or []:
 | |
|             self._add_assignment_clause(assignment)
 | |
| 
 | |
|     def update_context_id(self, i):
 | |
|         super(AssignmentStatement, self).update_context_id(i)
 | |
|         for assignment in self.assignments:
 | |
|             assignment.set_context_id(self.context_counter)
 | |
|             self.context_counter += assignment.get_context_size()
 | |
| 
 | |
|     def partition_key_values(self, field_index_map):
 | |
|         parts = super(AssignmentStatement, self).partition_key_values(field_index_map)
 | |
|         self._update_part_key_values(field_index_map, self.assignments, parts)
 | |
|         return parts
 | |
| 
 | |
|     def add_assignment(self, column, value):
 | |
|         value = column.to_database(value)
 | |
|         clause = AssignmentClause(column.db_field_name, value)
 | |
|         self._add_assignment_clause(clause)
 | |
| 
 | |
|     def _add_assignment_clause(self, clause):
 | |
|         clause.set_context_id(self.context_counter)
 | |
|         self.context_counter += clause.get_context_size()
 | |
|         self.assignments.append(clause)
 | |
| 
 | |
|     @property
 | |
|     def is_empty(self):
 | |
|         return len(self.assignments) == 0
 | |
| 
 | |
|     def get_context(self):
 | |
|         ctx = super(AssignmentStatement, self).get_context()
 | |
|         for clause in self.assignments:
 | |
|             clause.update_context(ctx)
 | |
|         return ctx
 | |
| 
 | |
| 
 | |
| class InsertStatement(AssignmentStatement):
 | |
|     """ an cql insert statement """
 | |
| 
 | |
|     def __init__(self,
 | |
|                  table,
 | |
|                  assignments=None,
 | |
|                  where=None,
 | |
|                  ttl=None,
 | |
|                  timestamp=None,
 | |
|                  if_not_exists=False):
 | |
|         super(InsertStatement, self).__init__(table,
 | |
|                                               assignments=assignments,
 | |
|                                               where=where,
 | |
|                                               ttl=ttl,
 | |
|                                               timestamp=timestamp)
 | |
| 
 | |
|         self.if_not_exists = if_not_exists
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         qs = ['INSERT INTO {0}'.format(self.table)]
 | |
| 
 | |
|         # get column names and context placeholders
 | |
|         fields = [a.insert_tuple() for a in self.assignments]
 | |
|         columns, values = zip(*fields)
 | |
| 
 | |
|         qs += ["({0})".format(', '.join(['"{0}"'.format(c) for c in columns]))]
 | |
|         qs += ['VALUES']
 | |
|         qs += ["({0})".format(', '.join(['%({0})s'.format(v) for v in values]))]
 | |
| 
 | |
|         if self.if_not_exists:
 | |
|             qs += ["IF NOT EXISTS"]
 | |
| 
 | |
|         if self.ttl:
 | |
|             qs += ["USING TTL {0}".format(self.ttl)]
 | |
| 
 | |
|         if self.timestamp:
 | |
|             qs += ["USING TIMESTAMP {0}".format(self.timestamp_normalized)]
 | |
| 
 | |
|         return ' '.join(qs)
 | |
| 
 | |
| 
 | |
| class UpdateStatement(AssignmentStatement):
 | |
|     """ an cql update select statement """
 | |
| 
 | |
|     def __init__(self,
 | |
|                  table,
 | |
|                  assignments=None,
 | |
|                  where=None,
 | |
|                  ttl=None,
 | |
|                  timestamp=None,
 | |
|                  conditionals=None,
 | |
|                  if_exists=False):
 | |
|         super(UpdateStatement, self). __init__(table,
 | |
|                                                assignments=assignments,
 | |
|                                                where=where,
 | |
|                                                ttl=ttl,
 | |
|                                                timestamp=timestamp,
 | |
|                                                conditionals=conditionals)
 | |
| 
 | |
|         self.if_exists = if_exists
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         qs = ['UPDATE', self.table]
 | |
| 
 | |
|         using_options = []
 | |
| 
 | |
|         if self.ttl:
 | |
|             using_options += ["TTL {0}".format(self.ttl)]
 | |
| 
 | |
|         if self.timestamp:
 | |
|             using_options += ["TIMESTAMP {0}".format(self.timestamp_normalized)]
 | |
| 
 | |
|         if using_options:
 | |
|             qs += ["USING {0}".format(" AND ".join(using_options))]
 | |
| 
 | |
|         qs += ['SET']
 | |
|         qs += [', '.join([six.text_type(c) for c in self.assignments])]
 | |
| 
 | |
|         if self.where_clauses:
 | |
|             qs += [self._where]
 | |
| 
 | |
|         if len(self.conditionals) > 0:
 | |
|             qs += [self._get_conditionals()]
 | |
| 
 | |
|         if self.if_exists:
 | |
|             qs += ["IF EXISTS"]
 | |
| 
 | |
|         return ' '.join(qs)
 | |
| 
 | |
|     def get_context(self):
 | |
|         ctx = super(UpdateStatement, self).get_context()
 | |
|         for clause in self.conditionals:
 | |
|             clause.update_context(ctx)
 | |
|         return ctx
 | |
| 
 | |
|     def update_context_id(self, i):
 | |
|         super(UpdateStatement, self).update_context_id(i)
 | |
|         for conditional in self.conditionals:
 | |
|             conditional.set_context_id(self.context_counter)
 | |
|             self.context_counter += conditional.get_context_size()
 | |
| 
 | |
|     def add_update(self, column, value, operation=None, previous=None):
 | |
|         value = column.to_database(value)
 | |
|         col_type = type(column)
 | |
|         container_update_type = ContainerUpdateClause.type_map.get(col_type)
 | |
|         if container_update_type:
 | |
|             previous = column.to_database(previous)
 | |
|             clause = container_update_type(column.db_field_name, value, operation, previous)
 | |
|         elif col_type == columns.Counter:
 | |
|             clause = CounterUpdateClause(column.db_field_name, value)
 | |
|         else:
 | |
|             clause = AssignmentClause(column.db_field_name, value)
 | |
|         if clause.get_context_size():  # this is to exclude map removals from updates. Can go away if we drop support for C* < 1.2.4 and remove two-phase updates
 | |
|             self._add_assignment_clause(clause)
 | |
| 
 | |
| 
 | |
| class DeleteStatement(BaseCQLStatement):
 | |
|     """ a cql delete statement """
 | |
| 
 | |
|     def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=None, if_exists=False):
 | |
|         super(DeleteStatement, self).__init__(
 | |
|             table,
 | |
|             where=where,
 | |
|             timestamp=timestamp,
 | |
|             conditionals=conditionals
 | |
|         )
 | |
|         self.fields = []
 | |
|         if isinstance(fields, six.string_types):
 | |
|             fields = [fields]
 | |
|         for field in fields or []:
 | |
|             self.add_field(field)
 | |
| 
 | |
|         self.if_exists = if_exists
 | |
| 
 | |
|     def update_context_id(self, i):
 | |
|         super(DeleteStatement, self).update_context_id(i)
 | |
|         for field in self.fields:
 | |
|             field.set_context_id(self.context_counter)
 | |
|             self.context_counter += field.get_context_size()
 | |
|         for t in self.conditionals:
 | |
|             t.set_context_id(self.context_counter)
 | |
|             self.context_counter += t.get_context_size()
 | |
| 
 | |
|     def get_context(self):
 | |
|         ctx = super(DeleteStatement, self).get_context()
 | |
|         for field in self.fields:
 | |
|             field.update_context(ctx)
 | |
|         for clause in self.conditionals:
 | |
|             clause.update_context(ctx)
 | |
|         return ctx
 | |
| 
 | |
|     def add_field(self, field):
 | |
|         if isinstance(field, six.string_types):
 | |
|             field = FieldDeleteClause(field)
 | |
|         if not isinstance(field, BaseClause):
 | |
|             raise StatementException("only instances of AssignmentClause can be added to statements")
 | |
|         field.set_context_id(self.context_counter)
 | |
|         self.context_counter += field.get_context_size()
 | |
|         self.fields.append(field)
 | |
| 
 | |
|     def __unicode__(self):
 | |
|         qs = ['DELETE']
 | |
|         if self.fields:
 | |
|             qs += [', '.join(['{0}'.format(f) for f in self.fields])]
 | |
|         qs += ['FROM', self.table]
 | |
| 
 | |
|         delete_option = []
 | |
| 
 | |
|         if self.timestamp:
 | |
|             delete_option += ["TIMESTAMP {0}".format(self.timestamp_normalized)]
 | |
| 
 | |
|         if delete_option:
 | |
|             qs += [" USING {0} ".format(" AND ".join(delete_option))]
 | |
| 
 | |
|         if self.where_clauses:
 | |
|             qs += [self._where]
 | |
| 
 | |
|         if self.conditionals:
 | |
|             qs += [self._get_conditionals()]
 | |
| 
 | |
|         if self.if_exists:
 | |
|             qs += ["IF EXISTS"]
 | |
| 
 | |
|         return ' '.join(qs)
 | 
