adding connection pool support to the queryset

This commit is contained in:
Blake Eggleston
2012-12-04 22:30:07 -08:00
parent 5118c36de9
commit 43cbca139d
2 changed files with 47 additions and 9 deletions

View File

@@ -138,7 +138,8 @@ class QuerySet(object):
self._only_fields = [] self._only_fields = []
#results cache #results cache
self._cursor = None self._con = None
self._cur = None
self._result_cache = None self._result_cache = None
self._result_idx = None self._result_idx = None
@@ -154,7 +155,7 @@ class QuerySet(object):
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
clone = self.__class__(self.model) clone = self.__class__(self.model)
for k,v in self.__dict__.items(): for k,v in self.__dict__.items():
if k in ['_cursor', '_result_cache', '_result_idx']: if k in ['_con', '_cur', '_result_cache', '_result_idx']:
clone.__dict__[k] = None clone.__dict__[k] = None
else: else:
clone.__dict__[k] = copy.deepcopy(v, memo) clone.__dict__[k] = copy.deepcopy(v, memo)
@@ -164,6 +165,12 @@ class QuerySet(object):
def __len__(self): def __len__(self):
return self.count() return self.count()
def __del__(self):
if self._con:
self._con.close()
self._con = None
self._cur = None
#----query generation / execution---- #----query generation / execution----
def _validate_where_syntax(self): def _validate_where_syntax(self):
@@ -216,25 +223,31 @@ class QuerySet(object):
def _execute_query(self): def _execute_query(self):
if self._result_cache is None: if self._result_cache is None:
with connection_manager() as con: self._con = connection_manager()
self._cursor = con.execute(self._select_query(), self._where_values()) self._cur = self._con.execute(self._select_query(), self._where_values())
self._result_cache = [None]*self._cursor.rowcount self._result_cache = [None]*self._cur.rowcount
def _fill_result_cache_to_idx(self, idx): def _fill_result_cache_to_idx(self, idx):
self._execute_query() self._execute_query()
if self._result_idx is None: if self._result_idx is None:
self._result_idx = -1 self._result_idx = -1
names = [i[0] for i in self._cursor.description]
qty = idx - self._result_idx qty = idx - self._result_idx
if qty < 1: if qty < 1:
return return
else: else:
for values in self._cursor.fetchmany(qty): names = [i[0] for i in self._cur.description]
for values in self._cur.fetchmany(qty):
value_dict = dict(zip(names, values)) value_dict = dict(zip(names, values))
self._result_idx += 1 self._result_idx += 1
self._result_cache[self._result_idx] = self._construct_instance(value_dict) self._result_cache[self._result_idx] = self._construct_instance(value_dict)
#return the connection to the connection pool if we have all objects
if self._result_cache and self._result_cache[-1] is not None:
self._con.close()
self._con = None
self._cur = None
def __iter__(self): def __iter__(self):
self._execute_query() self._execute_query()

View File

@@ -355,8 +355,33 @@ class TestQuerySetDelete(BaseQuerySetUsage):
with self.assertRaises(query.QueryException): with self.assertRaises(query.QueryException):
TestModel.objects(attempt_id=0).delete() TestModel.objects(attempt_id=0).delete()
class TestQuerySetConnectionHandling(BaseQuerySetUsage):
def test_conn_is_returned_after_filling_cache(self):
"""
Tests that the queryset returns it's connection after it's fetched all of it's results
"""
q = TestModel.objects(test_id=0)
#tuple of expected attempt_id, expected_result values
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert q._con is None
assert q._cur is None
def test_conn_is_returned_after_queryset_is_garbage_collected(self):
""" Tests that the connection is returned to the connection pool after the queryset is gc'd """
from cqlengine.connection import ConnectionPool
assert ConnectionPool._queue.qsize() == 1
q = TestModel.objects(test_id=0)
v = q[0]
assert ConnectionPool._queue.qsize() == 0
del q
assert ConnectionPool._queue.qsize() == 1