Merge pull request #508 from datastax/py323
PYTHON-323: Add ability to set query's fetch_size and limit using ORM
This commit is contained in:
@@ -162,8 +162,7 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET):
|
|||||||
|
|
||||||
elif isinstance(query, BaseCQLStatement):
|
elif isinstance(query, BaseCQLStatement):
|
||||||
params = query.get_context()
|
params = query.get_context()
|
||||||
query = str(query)
|
query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
|
||||||
query = SimpleStatement(query, consistency_level=consistency_level)
|
|
||||||
|
|
||||||
elif isinstance(query, six.string_types):
|
elif isinstance(query, six.string_types):
|
||||||
query = SimpleStatement(query, consistency_level=consistency_level)
|
query = SimpleStatement(query, consistency_level=consistency_level)
|
||||||
|
|||||||
@@ -289,8 +289,12 @@ class AbstractQuerySet(object):
|
|||||||
# results cache
|
# results cache
|
||||||
self._result_cache = None
|
self._result_cache = None
|
||||||
self._result_idx = None
|
self._result_idx = None
|
||||||
|
self._result_generator = None
|
||||||
|
|
||||||
self._distinct_fields = None
|
self._distinct_fields = None
|
||||||
|
|
||||||
|
self._count = None
|
||||||
|
|
||||||
self._batch = None
|
self._batch = None
|
||||||
self._ttl = getattr(model, '__default_ttl__', None)
|
self._ttl = getattr(model, '__default_ttl__', None)
|
||||||
self._consistency = None
|
self._consistency = None
|
||||||
@@ -298,6 +302,7 @@ class AbstractQuerySet(object):
|
|||||||
self._if_not_exists = False
|
self._if_not_exists = False
|
||||||
self._timeout = connection.NOT_SET
|
self._timeout = connection.NOT_SET
|
||||||
self._if_exists = False
|
self._if_exists = False
|
||||||
|
self._fetch_size = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def column_family_name(self):
|
def column_family_name(self):
|
||||||
@@ -324,7 +329,7 @@ class AbstractQuerySet(object):
|
|||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
clone = self.__class__(self.model)
|
clone = self.__class__(self.model)
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
if k in ['_con', '_cur', '_result_cache', '_result_idx']: # don't clone these
|
if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator']: # don't clone these
|
||||||
clone.__dict__[k] = None
|
clone.__dict__[k] = None
|
||||||
elif k == '_batch':
|
elif k == '_batch':
|
||||||
# we need to keep the same batch instance across
|
# we need to keep the same batch instance across
|
||||||
@@ -341,7 +346,7 @@ class AbstractQuerySet(object):
|
|||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
self._execute_query()
|
self._execute_query()
|
||||||
return len(self._result_cache)
|
return self.count()
|
||||||
|
|
||||||
# ----query generation / execution----
|
# ----query generation / execution----
|
||||||
|
|
||||||
@@ -365,7 +370,8 @@ class AbstractQuerySet(object):
|
|||||||
order_by=self._order,
|
order_by=self._order,
|
||||||
limit=self._limit,
|
limit=self._limit,
|
||||||
allow_filtering=self._allow_filtering,
|
allow_filtering=self._allow_filtering,
|
||||||
distinct_fields=self._distinct_fields
|
distinct_fields=self._distinct_fields,
|
||||||
|
fetch_size=self._fetch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# ----Reads------
|
# ----Reads------
|
||||||
@@ -374,7 +380,8 @@ class AbstractQuerySet(object):
|
|||||||
if self._batch:
|
if self._batch:
|
||||||
raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
|
raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
|
||||||
if self._result_cache is None:
|
if self._result_cache is None:
|
||||||
self._result_cache = list(self._execute(self._select_query()))
|
self._result_generator = (i for i in self._execute(self._select_query()))
|
||||||
|
self._result_cache = []
|
||||||
self._construct_result = self._get_result_constructor()
|
self._construct_result = self._get_result_constructor()
|
||||||
|
|
||||||
def _fill_result_cache_to_idx(self, idx):
|
def _fill_result_cache_to_idx(self, idx):
|
||||||
@@ -388,43 +395,62 @@ class AbstractQuerySet(object):
|
|||||||
else:
|
else:
|
||||||
for idx in range(qty):
|
for idx in range(qty):
|
||||||
self._result_idx += 1
|
self._result_idx += 1
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx])
|
self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx])
|
||||||
|
break
|
||||||
|
except IndexError:
|
||||||
|
self._result_cache.append(next(self._result_generator))
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
self._execute_query()
|
self._execute_query()
|
||||||
|
|
||||||
for idx in range(len(self._result_cache)):
|
idx = 0
|
||||||
|
while True:
|
||||||
|
if len(self._result_cache) <= idx:
|
||||||
|
try:
|
||||||
|
self._result_cache.append(next(self._result_generator))
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
instance = self._result_cache[idx]
|
instance = self._result_cache[idx]
|
||||||
if isinstance(instance, dict):
|
if isinstance(instance, dict):
|
||||||
self._fill_result_cache_to_idx(idx)
|
self._fill_result_cache_to_idx(idx)
|
||||||
yield self._result_cache[idx]
|
yield self._result_cache[idx]
|
||||||
|
|
||||||
|
idx += 1
|
||||||
|
|
||||||
def __getitem__(self, s):
|
def __getitem__(self, s):
|
||||||
self._execute_query()
|
self._execute_query()
|
||||||
|
|
||||||
num_results = len(self._result_cache)
|
|
||||||
|
|
||||||
if isinstance(s, slice):
|
if isinstance(s, slice):
|
||||||
# calculate the amount of results that need to be loaded
|
# calculate the amount of results that need to be loaded
|
||||||
end = num_results if s.step is None else s.step
|
end = s.stop
|
||||||
if end < 0:
|
if s.start < 0 or s.stop < 0:
|
||||||
end += num_results
|
end = self.count()
|
||||||
else:
|
|
||||||
end -= 1
|
try:
|
||||||
self._fill_result_cache_to_idx(end)
|
self._fill_result_cache_to_idx(end)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
return self._result_cache[s.start:s.stop:s.step]
|
return self._result_cache[s.start:s.stop:s.step]
|
||||||
else:
|
else:
|
||||||
# return the object at this index
|
try:
|
||||||
s = int(s)
|
s = int(s)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise TypeError('QuerySet indices must be integers')
|
||||||
|
|
||||||
# handle negative indexing
|
# Using negative indexing is costly since we have to execute a count()
|
||||||
if s < 0:
|
if s < 0:
|
||||||
|
num_results = self.count()
|
||||||
s += num_results
|
s += num_results
|
||||||
|
|
||||||
if s >= num_results:
|
try:
|
||||||
raise IndexError
|
|
||||||
else:
|
|
||||||
self._fill_result_cache_to_idx(s)
|
self._fill_result_cache_to_idx(s)
|
||||||
|
except StopIteration:
|
||||||
|
raise IndexError
|
||||||
|
|
||||||
return self._result_cache[s]
|
return self._result_cache[s]
|
||||||
|
|
||||||
def _get_result_constructor(self):
|
def _get_result_constructor(self):
|
||||||
@@ -615,10 +641,10 @@ class AbstractQuerySet(object):
|
|||||||
return self.filter(*args, **kwargs).get()
|
return self.filter(*args, **kwargs).get()
|
||||||
|
|
||||||
self._execute_query()
|
self._execute_query()
|
||||||
if len(self._result_cache) == 0:
|
if self.count() == 0:
|
||||||
raise self.model.DoesNotExist
|
raise self.model.DoesNotExist
|
||||||
elif len(self._result_cache) > 1:
|
elif self.count() > 1:
|
||||||
raise self.model.MultipleObjectsReturned('{0} objects found'.format(len(self._result_cache)))
|
raise self.model.MultipleObjectsReturned('{0} objects found'.format(self.count()))
|
||||||
else:
|
else:
|
||||||
return self[0]
|
return self[0]
|
||||||
|
|
||||||
@@ -679,13 +705,13 @@ class AbstractQuerySet(object):
|
|||||||
if self._batch:
|
if self._batch:
|
||||||
raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
|
raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
|
||||||
|
|
||||||
if self._result_cache is None:
|
if self._count is None:
|
||||||
query = self._select_query()
|
query = self._select_query()
|
||||||
query.count = True
|
query.count = True
|
||||||
result = self._execute(query)
|
result = self._execute(query)
|
||||||
return result[0]['count']
|
count_row = result[0].popitem()
|
||||||
else:
|
self._count = count_row[1]
|
||||||
return len(self._result_cache)
|
return self._count
|
||||||
|
|
||||||
def distinct(self, distinct_fields=None):
|
def distinct(self, distinct_fields=None):
|
||||||
"""
|
"""
|
||||||
@@ -734,7 +760,11 @@ class AbstractQuerySet(object):
|
|||||||
for user in User.objects().limit(100):
|
for user in User.objects().limit(100):
|
||||||
print(user)
|
print(user)
|
||||||
"""
|
"""
|
||||||
if not (v is None or isinstance(v, six.integer_types)):
|
|
||||||
|
if v is None:
|
||||||
|
v = 0
|
||||||
|
|
||||||
|
if not isinstance(v, six.integer_types):
|
||||||
raise TypeError
|
raise TypeError
|
||||||
if v == self._limit:
|
if v == self._limit:
|
||||||
return self
|
return self
|
||||||
@@ -746,6 +776,30 @@ class AbstractQuerySet(object):
|
|||||||
clone._limit = v
|
clone._limit = v
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
|
def fetch_size(self, v):
|
||||||
|
"""
|
||||||
|
Sets the number of rows that are fetched at a time.
|
||||||
|
|
||||||
|
*Note that driver's default fetch size is 5000.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
for user in User.objects().fetch_size(500):
|
||||||
|
print(user)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(v, six.integer_types):
|
||||||
|
raise TypeError
|
||||||
|
if v == self._fetch_size:
|
||||||
|
return self
|
||||||
|
|
||||||
|
if v < 1:
|
||||||
|
raise QueryException("fetch size less than 1 is not allowed")
|
||||||
|
|
||||||
|
clone = copy.deepcopy(self)
|
||||||
|
clone._fetch_size = v
|
||||||
|
return clone
|
||||||
|
|
||||||
def allow_filtering(self):
|
def allow_filtering(self):
|
||||||
"""
|
"""
|
||||||
Enables the (usually) unwise practive of querying on a clustering key without also defining a partition key
|
Enables the (usually) unwise practive of querying on a clustering key without also defining a partition key
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from datetime import datetime, timedelta
|
|||||||
import time
|
import time
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from cassandra.query import FETCH_SIZE_UNSET
|
||||||
from cassandra.cqlengine import UnicodeMixin
|
from cassandra.cqlengine import UnicodeMixin
|
||||||
from cassandra.cqlengine.functions import QueryValue
|
from cassandra.cqlengine.functions import QueryValue
|
||||||
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator
|
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator
|
||||||
@@ -470,13 +471,14 @@ class MapDeleteClause(BaseDeleteClause):
|
|||||||
class BaseCQLStatement(UnicodeMixin):
|
class BaseCQLStatement(UnicodeMixin):
|
||||||
""" The base cql statement class """
|
""" The base cql statement class """
|
||||||
|
|
||||||
def __init__(self, table, consistency=None, timestamp=None, where=None):
|
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None):
|
||||||
super(BaseCQLStatement, self).__init__()
|
super(BaseCQLStatement, self).__init__()
|
||||||
self.table = table
|
self.table = table
|
||||||
self.consistency = consistency
|
self.consistency = consistency
|
||||||
self.context_id = 0
|
self.context_id = 0
|
||||||
self.context_counter = self.context_id
|
self.context_counter = self.context_id
|
||||||
self.timestamp = timestamp
|
self.timestamp = timestamp
|
||||||
|
self.fetch_size = fetch_size if fetch_size else FETCH_SIZE_UNSET
|
||||||
|
|
||||||
self.where_clauses = []
|
self.where_clauses = []
|
||||||
for clause in where or []:
|
for clause in where or []:
|
||||||
@@ -556,7 +558,8 @@ class SelectStatement(BaseCQLStatement):
|
|||||||
order_by=None,
|
order_by=None,
|
||||||
limit=None,
|
limit=None,
|
||||||
allow_filtering=False,
|
allow_filtering=False,
|
||||||
distinct_fields=None):
|
distinct_fields=None,
|
||||||
|
fetch_size=None):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
:param where
|
:param where
|
||||||
@@ -565,7 +568,8 @@ class SelectStatement(BaseCQLStatement):
|
|||||||
super(SelectStatement, self).__init__(
|
super(SelectStatement, self).__init__(
|
||||||
table,
|
table,
|
||||||
consistency=consistency,
|
consistency=consistency,
|
||||||
where=where
|
where=where,
|
||||||
|
fetch_size=fetch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
self.fields = [fields] if isinstance(fields, six.string_types) else (fields or [])
|
self.fields = [fields] if isinstance(fields, six.string_types) else (fields or [])
|
||||||
@@ -577,10 +581,13 @@ class SelectStatement(BaseCQLStatement):
|
|||||||
|
|
||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
qs = ['SELECT']
|
qs = ['SELECT']
|
||||||
|
if self.distinct_fields:
|
||||||
if self.count:
|
if self.count:
|
||||||
qs += ['COUNT(*)']
|
qs += ['DISTINCT COUNT({0})'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
|
||||||
elif self.distinct_fields:
|
else:
|
||||||
qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
|
qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
|
||||||
|
elif self.count:
|
||||||
|
qs += ['COUNT(*)']
|
||||||
else:
|
else:
|
||||||
qs += [', '.join(['"{0}"'.format(f) for f in self.fields]) if self.fields else '*']
|
qs += [', '.join(['"{0}"'.format(f) for f in self.fields]) if self.fields else '*']
|
||||||
qs += ['FROM', self.table]
|
qs += ['FROM', self.table]
|
||||||
|
|||||||
@@ -485,6 +485,13 @@ class TestQuerySetDistinct(BaseQuerySetUsage):
|
|||||||
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[52])
|
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[52])
|
||||||
self.assertEqual(len(q), 0)
|
self.assertEqual(len(q), 0)
|
||||||
|
|
||||||
|
def test_distinct_with_explicit_count(self):
|
||||||
|
q = TestModel.objects.distinct(['test_id'])
|
||||||
|
self.assertEqual(q.count(), 3)
|
||||||
|
|
||||||
|
q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2])
|
||||||
|
self.assertEqual(q.count(), 2)
|
||||||
|
|
||||||
|
|
||||||
class TestQuerySetOrdering(BaseQuerySetUsage):
|
class TestQuerySetOrdering(BaseQuerySetUsage):
|
||||||
|
|
||||||
@@ -554,16 +561,31 @@ class TestQuerySetSlicing(BaseQuerySetUsage):
|
|||||||
def test_slicing_works_properly(self):
|
def test_slicing_works_properly(self):
|
||||||
q = TestModel.objects(test_id=0).order_by('attempt_id')
|
q = TestModel.objects(test_id=0).order_by('attempt_id')
|
||||||
expected_order = [0, 1, 2, 3]
|
expected_order = [0, 1, 2, 3]
|
||||||
|
|
||||||
for model, expect in zip(q[1:3], expected_order[1:3]):
|
for model, expect in zip(q[1:3], expected_order[1:3]):
|
||||||
assert model.attempt_id == expect
|
self.assertEqual(model.attempt_id, expect)
|
||||||
|
|
||||||
|
for model, expect in zip(q[0:3:2], expected_order[0:3:2]):
|
||||||
|
self.assertEqual(model.attempt_id, expect)
|
||||||
|
|
||||||
def test_negative_slicing(self):
|
def test_negative_slicing(self):
|
||||||
q = TestModel.objects(test_id=0).order_by('attempt_id')
|
q = TestModel.objects(test_id=0).order_by('attempt_id')
|
||||||
expected_order = [0, 1, 2, 3]
|
expected_order = [0, 1, 2, 3]
|
||||||
|
|
||||||
for model, expect in zip(q[-3:], expected_order[-3:]):
|
for model, expect in zip(q[-3:], expected_order[-3:]):
|
||||||
assert model.attempt_id == expect
|
self.assertEqual(model.attempt_id, expect)
|
||||||
|
|
||||||
for model, expect in zip(q[:-1], expected_order[:-1]):
|
for model, expect in zip(q[:-1], expected_order[:-1]):
|
||||||
assert model.attempt_id == expect
|
self.assertEqual(model.attempt_id, expect)
|
||||||
|
|
||||||
|
for model, expect in zip(q[1:-1], expected_order[1:-1]):
|
||||||
|
self.assertEqual(model.attempt_id, expect)
|
||||||
|
|
||||||
|
for model, expect in zip(q[-3:-1], expected_order[-3:-1]):
|
||||||
|
self.assertEqual(model.attempt_id, expect)
|
||||||
|
|
||||||
|
for model, expect in zip(q[-3:-1:2], expected_order[-3:-1:2]):
|
||||||
|
self.assertEqual(model.attempt_id, expect)
|
||||||
|
|
||||||
|
|
||||||
class TestQuerySetValidation(BaseQuerySetUsage):
|
class TestQuerySetValidation(BaseQuerySetUsage):
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
import unittest # noqa
|
import unittest # noqa
|
||||||
|
|
||||||
|
from cassandra.query import FETCH_SIZE_UNSET
|
||||||
from cassandra.cqlengine.statements import BaseCQLStatement, StatementException
|
from cassandra.cqlengine.statements import BaseCQLStatement, StatementException
|
||||||
|
|
||||||
|
|
||||||
@@ -26,3 +27,14 @@ class BaseStatementTest(unittest.TestCase):
|
|||||||
stmt = BaseCQLStatement('table', [])
|
stmt = BaseCQLStatement('table', [])
|
||||||
with self.assertRaises(StatementException):
|
with self.assertRaises(StatementException):
|
||||||
stmt.add_where_clause('x=5')
|
stmt.add_where_clause('x=5')
|
||||||
|
|
||||||
|
def test_fetch_size(self):
|
||||||
|
""" tests that fetch_size is correctly set """
|
||||||
|
stmt = BaseCQLStatement('table', None, fetch_size=1000)
|
||||||
|
self.assertEqual(stmt.fetch_size, 1000)
|
||||||
|
|
||||||
|
stmt = BaseCQLStatement('table', None, fetch_size=None)
|
||||||
|
self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET)
|
||||||
|
|
||||||
|
stmt = BaseCQLStatement('table', None)
|
||||||
|
self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET)
|
||||||
|
|||||||
@@ -64,6 +64,9 @@ class SelectStatementTests(unittest.TestCase):
|
|||||||
ss = SelectStatement('table', distinct_fields=['field1', 'field2'])
|
ss = SelectStatement('table', distinct_fields=['field1', 'field2'])
|
||||||
self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field1", "field2" FROM table')
|
self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field1", "field2" FROM table')
|
||||||
|
|
||||||
|
ss = SelectStatement('table', distinct_fields=['field1'], count=True)
|
||||||
|
self.assertEqual(six.text_type(ss), 'SELECT DISTINCT COUNT("field1") FROM table')
|
||||||
|
|
||||||
def test_context(self):
|
def test_context(self):
|
||||||
ss = SelectStatement('table')
|
ss = SelectStatement('table')
|
||||||
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||||
@@ -93,3 +96,15 @@ class SelectStatementTests(unittest.TestCase):
|
|||||||
self.assertIn('ORDER BY x, y', qstr)
|
self.assertIn('ORDER BY x, y', qstr)
|
||||||
self.assertIn('ALLOW FILTERING', qstr)
|
self.assertIn('ALLOW FILTERING', qstr)
|
||||||
|
|
||||||
|
def test_limit_rendering(self):
|
||||||
|
ss = SelectStatement('table', None, limit=10)
|
||||||
|
qstr = six.text_type(ss)
|
||||||
|
self.assertIn('LIMIT 10', qstr)
|
||||||
|
|
||||||
|
ss = SelectStatement('table', None, limit=0)
|
||||||
|
qstr = six.text_type(ss)
|
||||||
|
self.assertNotIn('LIMIT', qstr)
|
||||||
|
|
||||||
|
ss = SelectStatement('table', None, limit=None)
|
||||||
|
qstr = six.text_type(ss)
|
||||||
|
self.assertNotIn('LIMIT', qstr)
|
||||||
|
|||||||
Reference in New Issue
Block a user