refactoring update query

This commit is contained in:
Blake Eggleston
2013-11-03 08:43:05 -08:00
parent 7d98bfb67c
commit 1dd5b50bd7
2 changed files with 17 additions and 26 deletions

View File

@@ -16,7 +16,7 @@ from cqlengine.functions import QueryValue, Token
#http://www.datastax.com/docs/1.1/references/cql/index #http://www.datastax.com/docs/1.1/references/cql/index
from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator
from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause
class QueryException(CQLEngineException): pass class QueryException(CQLEngineException): pass
@@ -767,9 +767,8 @@ class ModelQuerySet(AbstractQuerySet):
if not values: if not values:
return return
set_statements = []
ctx = {}
nulled_columns = set() nulled_columns = set()
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl)
for name, val in values.items(): for name, val in values.items():
col = self.model._columns.get(name) col = self.model._columns.get(name)
# check for nonexistant columns # check for nonexistant columns
@@ -784,39 +783,28 @@ class ModelQuerySet(AbstractQuerySet):
nulled_columns.add(name) nulled_columns.add(name)
continue continue
# add the update statements # add the update statements
if isinstance(col, (BaseContainerColumn, Counter)): if isinstance(col, Counter):
val_mgr = self.instance._values[name] # TODO: implement counter updates
set_statements += col.get_update_statement(val, val_mgr.previous_value, ctx) raise NotImplementedError
else: else:
field_id = uuid4().hex us.add_assignment_clause(AssignmentClause(name, col.to_database(val)))
set_statements += ['"{}" = :{}'.format(col.db_field_name, field_id)]
ctx[field_id] = val
if set_statements: if us.assignments:
ttl_stmt = "USING TTL {}".format(self._ttl) if self._ttl else "" qs = str(us)
qs = "UPDATE {} SET {} WHERE {} {}".format( ctx = us.get_context()
self.column_family_name,
', '.join(set_statements),
self._where_clause(),
ttl_stmt
)
ctx.update(self._where_values())
if self._batch: if self._batch:
self._batch.add_query(qs, ctx) self._batch.add_query(qs, ctx)
else: else:
execute(qs, ctx, self._consistency) execute(qs, ctx, self._consistency)
if nulled_columns: if nulled_columns:
qs = "DELETE {} FROM {} WHERE {}".format( ds = DeleteStatement(self.column_family_name, fields=nulled_columns, where=self._where)
', '.join(nulled_columns), qs = str(ds)
self.column_family_name, ctx = ds.get_context()
self._where_clause()
)
if self._batch: if self._batch:
self._batch.add_query(qs, self._where_values()) self._batch.add_query(qs, ctx)
else: else:
execute(qs, self._where_values(), self._consistency) execute(qs, ctx, self._consistency)
class DMLQuery(object): class DMLQuery(object):

View File

@@ -112,3 +112,6 @@ class QueryUpdateTests(BaseCassEngTestCase):
assert row.cluster == i assert row.cluster == i
assert row.count == (6 if i == 3 else i) assert row.count == (6 if i == 3 else i)
assert row.text == (None if i == 3 else str(i)) assert row.text == (None if i == 3 else str(i))
def test_counter_updates(self):
pass