Merge pull request #508 from datastax/py323
PYTHON-323: Add ability to set query's fetch_size and limit using ORM
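In short: this change threads a fetch_size option from the cqlengine queryset down to the driver statement, reworks queryset results to page lazily through a generator instead of materializing every row up front, and makes count() memoize its result. A minimal usage sketch of the new API (User is a stand-in cqlengine model, not part of this diff):

    # page 500 rows at a time from the server, returning at most 1000 rows
    for user in User.objects().fetch_size(500).limit(1000):
        print(user)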
@@ -162,8 +162,7 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET):
     elif isinstance(query, BaseCQLStatement):
         params = query.get_context()
-        query = str(query)
-        query = SimpleStatement(query, consistency_level=consistency_level)
+        query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
     elif isinstance(query, six.string_types):
         query = SimpleStatement(query, consistency_level=consistency_level)
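The hunk above (cqlengine's connection-level execute()) now forwards the statement's fetch_size into the driver's SimpleStatement, which is what actually controls paging. For context, this mirrors the plain-driver pattern (sketch only; session setup omitted):

    from cassandra.query import SimpleStatement

    stmt = SimpleStatement("SELECT * FROM users", fetch_size=500)
    # session.execute(stmt) then pages through results 500 rows at a time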
@@ -289,8 +289,12 @@ class AbstractQuerySet(object):
         # results cache
         self._result_cache = None
         self._result_idx = None
+        self._result_generator = None

         self._distinct_fields = None

+        self._count = None
+
         self._batch = None
         self._ttl = getattr(model, '__default_ttl__', None)
         self._consistency = None
@@ -298,6 +302,7 @@ class AbstractQuerySet(object):
         self._if_not_exists = False
         self._timeout = connection.NOT_SET
         self._if_exists = False
+        self._fetch_size = None

     @property
     def column_family_name(self):
@@ -324,7 +329,7 @@ class AbstractQuerySet(object):
     def __deepcopy__(self, memo):
         clone = self.__class__(self.model)
         for k, v in self.__dict__.items():
-            if k in ['_con', '_cur', '_result_cache', '_result_idx']:  # don't clone these
+            if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator']:  # don't clone these
                 clone.__dict__[k] = None
             elif k == '_batch':
                 # we need to keep the same batch instance across
@@ -341,7 +346,7 @@ class AbstractQuerySet(object):

     def __len__(self):
-        self._execute_query()
-        return len(self._result_cache)
+        return self.count()

     # ----query generation / execution----
@@ -365,7 +370,8 @@ class AbstractQuerySet(object):
             order_by=self._order,
             limit=self._limit,
             allow_filtering=self._allow_filtering,
-            distinct_fields=self._distinct_fields
+            distinct_fields=self._distinct_fields,
+            fetch_size=self._fetch_size
         )

     # ----Reads------
@@ -374,7 +380,8 @@ class AbstractQuerySet(object):
         if self._batch:
             raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
         if self._result_cache is None:
-            self._result_cache = list(self._execute(self._select_query()))
+            self._result_generator = (i for i in self._execute(self._select_query()))
+            self._result_cache = []
+            self._construct_result = self._get_result_constructor()

     def _fill_result_cache_to_idx(self, idx):
@@ -388,44 +395,63 @@ class AbstractQuerySet(object):
         else:
             for idx in range(qty):
                 self._result_idx += 1
-                self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx])
+                while True:
+                    try:
+                        self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx])
+                        break
+                    except IndexError:
+                        self._result_cache.append(next(self._result_generator))

     def __iter__(self):
         self._execute_query()

-        for idx in range(len(self._result_cache)):
+        idx = 0
+        while True:
+            if len(self._result_cache) <= idx:
+                try:
+                    self._result_cache.append(next(self._result_generator))
+                except StopIteration:
+                    break
+
             instance = self._result_cache[idx]
             if isinstance(instance, dict):
                 self._fill_result_cache_to_idx(idx)
             yield self._result_cache[idx]

+            idx += 1
+
     def __getitem__(self, s):
         self._execute_query()

-        num_results = len(self._result_cache)
-
         if isinstance(s, slice):
             # calculate the amount of results that need to be loaded
-            end = num_results if s.step is None else s.step
-            if end < 0:
-                end += num_results
-            else:
-                end -= 1
-            self._fill_result_cache_to_idx(end)
+            end = s.stop
+            if s.start < 0 or s.stop < 0:
+                end = self.count()
+
+            try:
+                self._fill_result_cache_to_idx(end)
+            except StopIteration:
+                pass
+
             return self._result_cache[s.start:s.stop:s.step]
         else:
             # return the object at this index
-            s = int(s)
+            try:
+                s = int(s)
+            except (ValueError, TypeError):
+                raise TypeError('QuerySet indices must be integers')

-            # handle negative indexing
+            # Using negative indexing is costly since we have to execute a count()
             if s < 0:
+                num_results = self.count()
                 s += num_results

-            if s >= num_results:
-                raise IndexError
-            else:
+            try:
                 self._fill_result_cache_to_idx(s)
-                return self._result_cache[s]
+            except StopIteration:
+                raise IndexError
+
+            return self._result_cache[s]

     def _get_result_constructor(self):
         """
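Taken together, _execute_query, _fill_result_cache_to_idx, __iter__, and __getitem__ now draw rows from _result_generator on demand and append them to _result_cache, so a queryset only materializes as many rows as the caller actually touches. The caching scheme reduces to this standalone pattern (illustrative sketch, not code from the diff):

    def fetch_to_idx(generator, cache, idx):
        # Pull rows until the cache covers position idx. Past the end of
        # the results, next() raises StopIteration, which the queryset maps
        # to IndexError (indexing) or loop termination (iteration).
        while len(cache) <= idx:
            cache.append(next(generator))
        return cache[idx]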
@@ -615,10 +641,10 @@ class AbstractQuerySet(object):
             return self.filter(*args, **kwargs).get()

         self._execute_query()
-        if len(self._result_cache) == 0:
+        if self.count() == 0:
             raise self.model.DoesNotExist
-        elif len(self._result_cache) > 1:
-            raise self.model.MultipleObjectsReturned('{0} objects found'.format(len(self._result_cache)))
+        elif self.count() > 1:
+            raise self.model.MultipleObjectsReturned('{0} objects found'.format(self.count()))
         else:
             return self[0]
@@ -679,13 +705,13 @@ class AbstractQuerySet(object):
         if self._batch:
             raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")

-        if self._result_cache is None:
+        if self._count is None:
             query = self._select_query()
             query.count = True
             result = self._execute(query)
-            return result[0]['count']
-        else:
-            return len(self._result_cache)
+            count_row = result[0].popitem()
+            self._count = count_row[1]
+        return self._count

     def distinct(self, distinct_fields=None):
         """
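With this hunk, count() runs a SELECT COUNT query once and memoizes the result in _count, and __len__ and get() (above) now delegate to it rather than reading len(self._result_cache). A practical consequence, sketched with a hypothetical model and filter:

    q = User.objects(account_id=some_id)
    n = len(q)       # executes a COUNT query; the value is cached on q
    rows = list(q)   # separate SELECT, paged lazily through the generator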
@@ -734,7 +760,11 @@ class AbstractQuerySet(object):
             for user in User.objects().limit(100):
                 print(user)
         """
-        if not (v is None or isinstance(v, six.integer_types)):
+
+        if v is None:
+            v = 0
+
+        if not isinstance(v, six.integer_types):
             raise TypeError
         if v == self._limit:
             return self
@@ -746,6 +776,30 @@ class AbstractQuerySet(object):
         clone._limit = v
         return clone

+    def fetch_size(self, v):
+        """
+        Sets the number of rows that are fetched at a time.
+
+        *Note that the driver's default fetch size is 5000.
+
+        .. code-block:: python
+
+            for user in User.objects().fetch_size(500):
+                print(user)
+        """
+
+        if not isinstance(v, six.integer_types):
+            raise TypeError
+        if v == self._fetch_size:
+            return self
+
+        if v < 1:
+            raise QueryException("fetch size less than 1 is not allowed")
+
+        clone = copy.deepcopy(self)
+        clone._fetch_size = v
+        return clone
+
     def allow_filtering(self):
         """
         Enables the (usually) unwise practice of querying on a clustering key without also defining a partition key
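The reworked limit() now accepts None as an alias for 0, meaning no LIMIT clause is rendered (see test_limit_rendering at the end of this diff), while the new fetch_size() insists on a positive integer. Illustrative calls:

    User.objects().limit(None)    # same as limit(0): no LIMIT in the CQL
    User.objects().fetch_size(0)  # raises QueryException: fetch size less than 1 is not allowed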
@@ -16,6 +16,7 @@ from datetime import datetime, timedelta
 import time
 import six

+from cassandra.query import FETCH_SIZE_UNSET
 from cassandra.cqlengine import UnicodeMixin
 from cassandra.cqlengine.functions import QueryValue
 from cassandra.cqlengine.operators import BaseWhereOperator, InOperator
@@ -470,13 +471,14 @@ class MapDeleteClause(BaseDeleteClause):
 class BaseCQLStatement(UnicodeMixin):
     """ The base cql statement class """

-    def __init__(self, table, consistency=None, timestamp=None, where=None):
+    def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None):
         super(BaseCQLStatement, self).__init__()
         self.table = table
         self.consistency = consistency
         self.context_id = 0
         self.context_counter = self.context_id
         self.timestamp = timestamp
+        self.fetch_size = fetch_size if fetch_size else FETCH_SIZE_UNSET

         self.where_clauses = []
         for clause in where or []:
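Note the truthiness check on the new fetch_size parameter: any falsy value, including an explicit 0, falls back to FETCH_SIZE_UNSET. The queryset layer rejects values below 1 before they reach this point, so in practice only None takes that branch. Illustrative check:

    stmt = BaseCQLStatement('table', fetch_size=0)
    assert stmt.fetch_size == FETCH_SIZE_UNSET  # 0 is falsy, so it reads as unset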
@@ -556,7 +558,8 @@ class SelectStatement(BaseCQLStatement):
                  order_by=None,
                  limit=None,
                  allow_filtering=False,
-                 distinct_fields=None):
+                 distinct_fields=None,
+                 fetch_size=None):

         """
         :param where
@@ -565,7 +568,8 @@ class SelectStatement(BaseCQLStatement):
         super(SelectStatement, self).__init__(
             table,
             consistency=consistency,
-            where=where
+            where=where,
+            fetch_size=fetch_size
         )

         self.fields = [fields] if isinstance(fields, six.string_types) else (fields or [])
@@ -577,10 +581,13 @@ class SelectStatement(BaseCQLStatement):

     def __unicode__(self):
         qs = ['SELECT']
-        if self.count:
+        if self.distinct_fields:
+            if self.count:
+                qs += ['DISTINCT COUNT({0})'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
+            else:
+                qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
+        elif self.count:
             qs += ['COUNT(*)']
-        elif self.distinct_fields:
-            qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
         else:
             qs += [', '.join(['"{0}"'.format(f) for f in self.fields]) if self.fields else '*']
         qs += ['FROM', self.table]
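The reordered branches in __unicode__ let DISTINCT and COUNT compose instead of being mutually exclusive, mirroring the new unit test near the end of this diff ('users' and 'user_id' here are illustrative):

    ss = SelectStatement('users', distinct_fields=['user_id'], count=True)
    print(six.text_type(ss))  # SELECT DISTINCT COUNT("user_id") FROM users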
@@ -485,6 +485,13 @@ class TestQuerySetDistinct(BaseQuerySetUsage):
         q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[52])
         self.assertEqual(len(q), 0)

+    def test_distinct_with_explicit_count(self):
+        q = TestModel.objects.distinct(['test_id'])
+        self.assertEqual(q.count(), 3)
+
+        q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2])
+        self.assertEqual(q.count(), 2)
+

 class TestQuerySetOrdering(BaseQuerySetUsage):
@@ -554,16 +561,31 @@ class TestQuerySetSlicing(BaseQuerySetUsage):
     def test_slicing_works_properly(self):
         q = TestModel.objects(test_id=0).order_by('attempt_id')
         expected_order = [0, 1, 2, 3]

         for model, expect in zip(q[1:3], expected_order[1:3]):
-            assert model.attempt_id == expect
+            self.assertEqual(model.attempt_id, expect)
+
+        for model, expect in zip(q[0:3:2], expected_order[0:3:2]):
+            self.assertEqual(model.attempt_id, expect)

     def test_negative_slicing(self):
         q = TestModel.objects(test_id=0).order_by('attempt_id')
         expected_order = [0, 1, 2, 3]

         for model, expect in zip(q[-3:], expected_order[-3:]):
-            assert model.attempt_id == expect
+            self.assertEqual(model.attempt_id, expect)

         for model, expect in zip(q[:-1], expected_order[:-1]):
-            assert model.attempt_id == expect
+            self.assertEqual(model.attempt_id, expect)
+
+        for model, expect in zip(q[1:-1], expected_order[1:-1]):
+            self.assertEqual(model.attempt_id, expect)
+
+        for model, expect in zip(q[-3:-1], expected_order[-3:-1]):
+            self.assertEqual(model.attempt_id, expect)
+
+        for model, expect in zip(q[-3:-1:2], expected_order[-3:-1:2]):
+            self.assertEqual(model.attempt_id, expect)


 class TestQuerySetValidation(BaseQuerySetUsage):
@@ -16,6 +16,7 @@ try:
 except ImportError:
     import unittest  # noqa

+from cassandra.query import FETCH_SIZE_UNSET
 from cassandra.cqlengine.statements import BaseCQLStatement, StatementException
@@ -26,3 +27,14 @@ class BaseStatementTest(unittest.TestCase):
         stmt = BaseCQLStatement('table', [])
         with self.assertRaises(StatementException):
             stmt.add_where_clause('x=5')
+
+    def test_fetch_size(self):
+        """ tests that fetch_size is correctly set """
+        stmt = BaseCQLStatement('table', None, fetch_size=1000)
+        self.assertEqual(stmt.fetch_size, 1000)
+
+        stmt = BaseCQLStatement('table', None, fetch_size=None)
+        self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET)
+
+        stmt = BaseCQLStatement('table', None)
+        self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET)
@@ -64,6 +64,9 @@ class SelectStatementTests(unittest.TestCase):
         ss = SelectStatement('table', distinct_fields=['field1', 'field2'])
         self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field1", "field2" FROM table')

+        ss = SelectStatement('table', distinct_fields=['field1'], count=True)
+        self.assertEqual(six.text_type(ss), 'SELECT DISTINCT COUNT("field1") FROM table')
+
     def test_context(self):
         ss = SelectStatement('table')
         ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
@@ -93,3 +96,15 @@ class SelectStatementTests(unittest.TestCase):
         self.assertIn('ORDER BY x, y', qstr)
         self.assertIn('ALLOW FILTERING', qstr)
+
+    def test_limit_rendering(self):
+        ss = SelectStatement('table', None, limit=10)
+        qstr = six.text_type(ss)
+        self.assertIn('LIMIT 10', qstr)
+
+        ss = SelectStatement('table', None, limit=0)
+        qstr = six.text_type(ss)
+        self.assertNotIn('LIMIT', qstr)
+
+        ss = SelectStatement('table', None, limit=None)
+        qstr = six.text_type(ss)
+        self.assertNotIn('LIMIT', qstr)