Merge pull request #508 from datastax/py323

PYTHON-323: Add ability to set query's fetch_size and limit using ORM
Adam Holmberg
2016-03-03 16:20:25 -06:00
6 changed files with 149 additions and 40 deletions


@@ -162,8 +162,7 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET):
elif isinstance(query, BaseCQLStatement):
params = query.get_context()
query = str(query)
query = SimpleStatement(query, consistency_level=consistency_level)
query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
elif isinstance(query, six.string_types):
query = SimpleStatement(query, consistency_level=consistency_level)
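For context, a minimal sketch of what the rewritten branch does; the stand-in statement class here is illustrative, not driver API. The CQL text is rendered with str() and the statement's fetch_size rides along into the SimpleStatement, where the driver's FETCH_SIZE_UNSET sentinel means "defer to the session default":

from cassandra.query import SimpleStatement, FETCH_SIZE_UNSET

class StubStatement(object):
    # Hypothetical stand-in for a cqlengine BaseCQLStatement.
    def __init__(self, cql, fetch_size=None):
        self._cql = cql
        self.fetch_size = fetch_size if fetch_size else FETCH_SIZE_UNSET

    def __str__(self):
        return self._cql

stmt = StubStatement('SELECT * FROM ks.users', fetch_size=500)
# Mirrors the new execute() branch: render and wrap in a single step.
wrapped = SimpleStatement(str(stmt), fetch_size=stmt.fetch_size)
assert wrapped.fetch_size == 500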


@@ -289,8 +289,12 @@ class AbstractQuerySet(object):
# results cache
self._result_cache = None
self._result_idx = None
self._result_generator = None
self._distinct_fields = None
self._count = None
self._batch = None
self._ttl = getattr(model, '__default_ttl__', None)
self._consistency = None
@@ -298,6 +302,7 @@ class AbstractQuerySet(object):
self._if_not_exists = False
self._timeout = connection.NOT_SET
self._if_exists = False
self._fetch_size = None
@property
def column_family_name(self):
@@ -324,7 +329,7 @@ class AbstractQuerySet(object):
def __deepcopy__(self, memo):
clone = self.__class__(self.model)
for k, v in self.__dict__.items():
if k in ['_con', '_cur', '_result_cache', '_result_idx']: # don't clone these
if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator']: # don't clone these
clone.__dict__[k] = None
elif k == '_batch':
# we need to keep the same batch instance across
@@ -341,7 +346,7 @@ class AbstractQuerySet(object):
def __len__(self):
self._execute_query()
return len(self._result_cache)
return self.count()
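Worth flagging: len(queryset) no longer materializes the whole result cache; it delegates to count(), which flips the statement's count flag and lets the server aggregate. A quick illustration (table name assumed):

from cassandra.cqlengine.statements import SelectStatement
import six

# count() renders a server-side aggregate instead of pulling rows.
assert six.text_type(SelectStatement('users', count=True)) == 'SELECT COUNT(*) FROM users'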
# ----query generation / execution----
@@ -365,7 +370,8 @@ class AbstractQuerySet(object):
order_by=self._order,
limit=self._limit,
allow_filtering=self._allow_filtering,
distinct_fields=self._distinct_fields
distinct_fields=self._distinct_fields,
fetch_size=self._fetch_size
)
# ----Reads------
@@ -374,7 +380,8 @@ class AbstractQuerySet(object):
if self._batch:
raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
if self._result_cache is None:
self._result_cache = list(self._execute(self._select_query()))
self._result_generator = (i for i in self._execute(self._select_query()))
self._result_cache = []
self._construct_result = self._get_result_constructor()
def _fill_result_cache_to_idx(self, idx):
@@ -388,44 +395,63 @@ class AbstractQuerySet(object):
else:
for idx in range(qty):
self._result_idx += 1
self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx])
while True:
try:
self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx])
break
except IndexError:
self._result_cache.append(next(self._result_generator))
def __iter__(self):
self._execute_query()
for idx in range(len(self._result_cache)):
idx = 0
while True:
if len(self._result_cache) <= idx:
try:
self._result_cache.append(next(self._result_generator))
except StopIteration:
break
instance = self._result_cache[idx]
if isinstance(instance, dict):
self._fill_result_cache_to_idx(idx)
yield self._result_cache[idx]
idx += 1
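Reduced to a standalone sketch, with a plain generator standing in for the driver's paged result set: rows enter the cache only as iteration reaches them, so abandoning the loop early never pulls pages that were never needed.

def lazy_iter(result_generator):
    # Cache rows as they are first reached; later passes reuse the cache.
    cache = []
    idx = 0
    while True:
        if len(cache) <= idx:
            try:
                cache.append(next(result_generator))
            except StopIteration:
                break
        yield cache[idx]
        idx += 1

rows = lazy_iter(iter([{'id': 1}, {'id': 2}, {'id': 3}]))
next(rows)  # only the first row has been drawn from the generator so far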
def __getitem__(self, s):
self._execute_query()
num_results = len(self._result_cache)
if isinstance(s, slice):
# calculate the number of results that need to be loaded
end = num_results if s.step is None else s.step
if end < 0:
end += num_results
else:
end -= 1
self._fill_result_cache_to_idx(end)
end = s.stop
if s.start < 0 or s.stop < 0:
end = self.count()
try:
self._fill_result_cache_to_idx(end)
except StopIteration:
pass
return self._result_cache[s.start:s.stop:s.step]
else:
# return the object at this index
s = int(s)
try:
s = int(s)
except (ValueError, TypeError):
raise TypeError('QuerySet indices must be integers')
# handle negative indexing
# Using negative indexing is costly since we have to execute a count()
if s < 0:
num_results = self.count()
s += num_results
if s >= num_results:
raise IndexError
else:
try:
self._fill_result_cache_to_idx(s)
return self._result_cache[s]
except StopIteration:
raise IndexError
return self._result_cache[s]
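The same logic condensed into a self-contained sketch, where count_fn stands in for the COUNT(*) round trip the comment warns about: negative indices are resolved against a count first, positive ones just fill the cache forward.

class LazySeq(object):
    # Illustrative only; handles int indices, not slices.
    def __init__(self, gen, count_fn):
        self._gen = gen          # stands in for the paged result generator
        self._cache = []
        self._count = count_fn   # stands in for the COUNT(*) round trip

    def __getitem__(self, s):
        try:
            s = int(s)
        except (ValueError, TypeError):
            raise TypeError('indices must be integers')
        if s < 0:                # negative indexing pays for a count()
            s += self._count()
            if s < 0:
                raise IndexError(s)
        while len(self._cache) <= s:
            try:
                self._cache.append(next(self._gen))
            except StopIteration:
                raise IndexError(s)
        return self._cache[s]

seq = LazySeq(iter('abcd'), lambda: 4)
assert seq[1] == 'b' and seq[-1] == 'd'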
def _get_result_constructor(self):
"""
@@ -615,10 +641,10 @@ class AbstractQuerySet(object):
return self.filter(*args, **kwargs).get()
self._execute_query()
if len(self._result_cache) == 0:
if self.count() == 0:
raise self.model.DoesNotExist
elif len(self._result_cache) > 1:
raise self.model.MultipleObjectsReturned('{0} objects found'.format(len(self._result_cache)))
elif self.count() > 1:
raise self.model.MultipleObjectsReturned('{0} objects found'.format(self.count()))
else:
return self[0]
@@ -679,13 +705,13 @@ class AbstractQuerySet(object):
if self._batch:
raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
if self._result_cache is None:
if self._count is None:
query = self._select_query()
query.count = True
result = self._execute(query)
return result[0]['count']
else:
return len(self._result_cache)
count_row = result[0].popitem()
self._count = count_row[1]
return self._count
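The popitem() step is presumably there so the count column's alias need not be hard-coded: the count row is a one-column mapping, and popping its lone (name, value) pair works whatever the server calls the column. Roughly:

row = {'count': 42}           # illustrative one-column count row
_name, value = row.popitem()  # value is cached on the queryset as _count
assert value == 42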
def distinct(self, distinct_fields=None):
"""
@@ -734,7 +760,11 @@ class AbstractQuerySet(object):
for user in User.objects().limit(100):
print(user)
"""
if not (v is None or isinstance(v, six.integer_types)):
if v is None:
v = 0
if not isinstance(v, six.integer_types):
raise TypeError
if v == self._limit:
return self
@@ -746,6 +776,30 @@ class AbstractQuerySet(object):
clone._limit = v
return clone
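Behaviorally, limit(None) now normalizes to 0, and a limit of 0 renders no LIMIT clause at all, which the new rendering test at the bottom of this diff pins down. A quick check (table name assumed):

from cassandra.cqlengine.statements import SelectStatement
import six

assert 'LIMIT 10' in six.text_type(SelectStatement('users', limit=10))
assert 'LIMIT' not in six.text_type(SelectStatement('users', limit=0))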
def fetch_size(self, v):
"""
Sets the number of rows that are fetched at a time.
Note that the driver's default fetch size is 5000.
.. code-block:: python
for user in User.objects().fetch_size(500):
print(user)
"""
if not isinstance(v, six.integer_types):
raise TypeError
if v == self._fetch_size:
return self
if v < 1:
raise QueryException("fetch size less than 1 is not allowed")
clone = copy.deepcopy(self)
clone._fetch_size = v
return clone
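The queryset value ultimately lands on the statement; left unset, it falls back to the driver's FETCH_SIZE_UNSET sentinel, i.e. the session default (table name assumed, using the constructor parameter added in this diff):

from cassandra.cqlengine.statements import SelectStatement
from cassandra.query import FETCH_SIZE_UNSET

assert SelectStatement('users', fetch_size=500).fetch_size == 500
assert SelectStatement('users').fetch_size == FETCH_SIZE_UNSET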
def allow_filtering(self):
"""
Enables the (usually) unwise practice of querying on a clustering key without also defining a partition key


@@ -16,6 +16,7 @@ from datetime import datetime, timedelta
import time
import six
from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine import UnicodeMixin
from cassandra.cqlengine.functions import QueryValue
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator
@@ -470,13 +471,14 @@ class MapDeleteClause(BaseDeleteClause):
class BaseCQLStatement(UnicodeMixin):
""" The base cql statement class """
def __init__(self, table, consistency=None, timestamp=None, where=None):
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None):
super(BaseCQLStatement, self).__init__()
self.table = table
self.consistency = consistency
self.context_id = 0
self.context_counter = self.context_id
self.timestamp = timestamp
self.fetch_size = fetch_size if fetch_size else FETCH_SIZE_UNSET
self.where_clauses = []
for clause in where or []:
@@ -556,7 +558,8 @@ class SelectStatement(BaseCQLStatement):
order_by=None,
limit=None,
allow_filtering=False,
distinct_fields=None):
distinct_fields=None,
fetch_size=None):
"""
:param where
@@ -565,7 +568,8 @@ class SelectStatement(BaseCQLStatement):
super(SelectStatement, self).__init__(
table,
consistency=consistency,
where=where
where=where,
fetch_size=fetch_size
)
self.fields = [fields] if isinstance(fields, six.string_types) else (fields or [])
@@ -577,10 +581,13 @@ class SelectStatement(BaseCQLStatement):
def __unicode__(self):
qs = ['SELECT']
if self.count:
if self.distinct_fields:
if self.count:
qs += ['DISTINCT COUNT({0})'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
else:
qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
elif self.count:
qs += ['COUNT(*)']
elif self.distinct_fields:
qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
else:
qs += [', '.join(['"{0}"'.format(f) for f in self.fields]) if self.fields else '*']
qs += ['FROM', self.table]
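After the reordering, the four rendering branches come out as follows (table and field names illustrative):

from cassandra.cqlengine.statements import SelectStatement
import six

print(six.text_type(SelectStatement('t', distinct_fields=['f'], count=True)))
# SELECT DISTINCT COUNT("f") FROM t
print(six.text_type(SelectStatement('t', distinct_fields=['f'])))
# SELECT DISTINCT "f" FROM t
print(six.text_type(SelectStatement('t', count=True)))
# SELECT COUNT(*) FROM t
print(six.text_type(SelectStatement('t', ['f'])))
# SELECT "f" FROM t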


@@ -485,6 +485,13 @@ class TestQuerySetDistinct(BaseQuerySetUsage):
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[52])
self.assertEqual(len(q), 0)
def test_distinct_with_explicit_count(self):
q = TestModel.objects.distinct(['test_id'])
self.assertEqual(q.count(), 3)
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2])
self.assertEqual(q.count(), 2)
class TestQuerySetOrdering(BaseQuerySetUsage):
@@ -554,16 +561,31 @@ class TestQuerySetSlicing(BaseQuerySetUsage):
def test_slicing_works_properly(self):
q = TestModel.objects(test_id=0).order_by('attempt_id')
expected_order = [0, 1, 2, 3]
for model, expect in zip(q[1:3], expected_order[1:3]):
assert model.attempt_id == expect
self.assertEqual(model.attempt_id, expect)
for model, expect in zip(q[0:3:2], expected_order[0:3:2]):
self.assertEqual(model.attempt_id, expect)
def test_negative_slicing(self):
q = TestModel.objects(test_id=0).order_by('attempt_id')
expected_order = [0, 1, 2, 3]
for model, expect in zip(q[-3:], expected_order[-3:]):
assert model.attempt_id == expect
self.assertEqual(model.attempt_id, expect)
for model, expect in zip(q[:-1], expected_order[:-1]):
assert model.attempt_id == expect
self.assertEqual(model.attempt_id, expect)
for model, expect in zip(q[1:-1], expected_order[1:-1]):
self.assertEqual(model.attempt_id, expect)
for model, expect in zip(q[-3:-1], expected_order[-3:-1]):
self.assertEqual(model.attempt_id, expect)
for model, expect in zip(q[-3:-1:2], expected_order[-3:-1:2]):
self.assertEqual(model.attempt_id, expect)
class TestQuerySetValidation(BaseQuerySetUsage):


@@ -16,6 +16,7 @@ try:
except ImportError:
import unittest # noqa
from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine.statements import BaseCQLStatement, StatementException
@@ -26,3 +27,14 @@ class BaseStatementTest(unittest.TestCase):
stmt = BaseCQLStatement('table', [])
with self.assertRaises(StatementException):
stmt.add_where_clause('x=5')
def test_fetch_size(self):
""" tests that fetch_size is correctly set """
stmt = BaseCQLStatement('table', None, fetch_size=1000)
self.assertEqual(stmt.fetch_size, 1000)
stmt = BaseCQLStatement('table', None, fetch_size=None)
self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET)
stmt = BaseCQLStatement('table', None)
self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET)


@@ -64,6 +64,9 @@ class SelectStatementTests(unittest.TestCase):
ss = SelectStatement('table', distinct_fields=['field1', 'field2'])
self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field1", "field2" FROM table')
ss = SelectStatement('table', distinct_fields=['field1'], count=True)
self.assertEqual(six.text_type(ss), 'SELECT DISTINCT COUNT("field1") FROM table')
def test_context(self):
ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
@@ -93,3 +96,15 @@ class SelectStatementTests(unittest.TestCase):
self.assertIn('ORDER BY x, y', qstr)
self.assertIn('ALLOW FILTERING', qstr)
def test_limit_rendering(self):
ss = SelectStatement('table', None, limit=10)
qstr = six.text_type(ss)
self.assertIn('LIMIT 10', qstr)
ss = SelectStatement('table', None, limit=0)
qstr = six.text_type(ss)
self.assertNotIn('LIMIT', qstr)
ss = SelectStatement('table', None, limit=None)
qstr = six.text_type(ss)
self.assertNotIn('LIMIT', qstr)