db_pool: customizable connection cleanup function; Thanks to Avery Fay
https://github.com/eventlet/eventlet/pull/64 Also: - PEP8 - except Exception - .put() must not catch SystemExit
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
import sys
|
||||
import time
|
||||
|
||||
@@ -11,16 +12,24 @@ from eventlet.hubs.timer import Timer
|
||||
from eventlet.greenthread import GreenThread
|
||||
|
||||
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
class ConnectTimeout(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def cleanup_rollback(conn):
|
||||
conn.rollback()
|
||||
|
||||
|
||||
class BaseConnectionPool(Pool):
|
||||
def __init__(self, db_module,
|
||||
min_size = 0, max_size = 4,
|
||||
max_idle = 10, max_age = 30,
|
||||
connect_timeout = 5,
|
||||
*args, **kwargs):
|
||||
min_size=0, max_size=4,
|
||||
max_idle=10, max_age=30,
|
||||
connect_timeout=5,
|
||||
cleanup=cleanup_rollback,
|
||||
*args, **kwargs):
|
||||
"""
|
||||
Constructs a pool with at least *min_size* connections and at most
|
||||
*max_size* connections. Uses *db_module* to construct new connections.
|
||||
@@ -49,19 +58,20 @@ class BaseConnectionPool(Pool):
|
||||
self.max_age = max_age
|
||||
self.connect_timeout = connect_timeout
|
||||
self._expiration_timer = None
|
||||
self.cleanup = cleanup
|
||||
super(BaseConnectionPool, self).__init__(min_size=min_size,
|
||||
max_size=max_size,
|
||||
order_as_stack=True)
|
||||
|
||||
def _schedule_expiration(self):
|
||||
""" Sets up a timer that will call _expire_old_connections when the
|
||||
"""Sets up a timer that will call _expire_old_connections when the
|
||||
oldest connection currently in the free pool is ready to expire. This
|
||||
is the earliest possible time that a connection could expire, thus, the
|
||||
timer will be running as infrequently as possible without missing a
|
||||
possible expiration.
|
||||
|
||||
If this function is called when a timer is already scheduled, it does
|
||||
nothing.
|
||||
nothing.
|
||||
|
||||
If max_age or max_idle is 0, _schedule_expiration likewise does nothing.
|
||||
"""
|
||||
@@ -70,8 +80,8 @@ class BaseConnectionPool(Pool):
|
||||
# on put
|
||||
return
|
||||
|
||||
if ( self._expiration_timer is not None
|
||||
and not getattr(self._expiration_timer, 'called', False)):
|
||||
if (self._expiration_timer is not None
|
||||
and not getattr(self._expiration_timer, 'called', False)):
|
||||
# the next timer is already scheduled
|
||||
return
|
||||
|
||||
@@ -97,7 +107,7 @@ class BaseConnectionPool(Pool):
|
||||
self._expiration_timer.schedule()
|
||||
|
||||
def _expire_old_connections(self, now):
|
||||
""" Iterates through the open connections contained in the pool, closing
|
||||
"""Iterates through the open connections contained in the pool, closing
|
||||
ones that have remained idle for longer than max_idle seconds, or have
|
||||
been in existence for longer than max_age seconds.
|
||||
|
||||
@@ -124,16 +134,16 @@ class BaseConnectionPool(Pool):
|
||||
self._safe_close(conn, quiet=True)
|
||||
|
||||
def _is_expired(self, now, last_used, created_at):
|
||||
""" Returns true and closes the connection if it's expired."""
|
||||
if ( self.max_idle <= 0
|
||||
or self.max_age <= 0
|
||||
or now - last_used > self.max_idle
|
||||
or now - created_at > self.max_age ):
|
||||
"""Returns true and closes the connection if it's expired.
|
||||
"""
|
||||
if (self.max_idle <= 0 or self.max_age <= 0
|
||||
or now - last_used > self.max_idle
|
||||
or now - created_at > self.max_age):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _unwrap_connection(self, conn):
|
||||
""" If the connection was wrapped by a subclass of
|
||||
"""If the connection was wrapped by a subclass of
|
||||
BaseConnectionWrapper and is still functional (as determined
|
||||
by the __nonzero__, or __bool__ in python3, method), returns
|
||||
the unwrapped connection. If anything goes wrong with this
|
||||
@@ -150,16 +160,15 @@ class BaseConnectionPool(Pool):
|
||||
pass
|
||||
return base
|
||||
|
||||
def _safe_close(self, conn, quiet = False):
|
||||
""" Closes the (already unwrapped) connection, squelching any
|
||||
exceptions."""
|
||||
def _safe_close(self, conn, quiet=False):
|
||||
"""Closes the (already unwrapped) connection, squelching any
|
||||
exceptions.
|
||||
"""
|
||||
try:
|
||||
conn.close()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
raise
|
||||
except AttributeError:
|
||||
pass # conn is None, or junk
|
||||
except:
|
||||
pass # conn is None, or junk
|
||||
except Exception:
|
||||
if not quiet:
|
||||
print("Connection.close raised: %s" % (sys.exc_info()[1]))
|
||||
|
||||
@@ -193,7 +202,7 @@ class BaseConnectionPool(Pool):
|
||||
wrapped._db_pool_created_at = created_at
|
||||
return wrapped
|
||||
|
||||
def put(self, conn):
|
||||
def put(self, conn, cleanup=_MISSING):
|
||||
created_at = getattr(conn, '_db_pool_created_at', 0)
|
||||
now = time.time()
|
||||
conn = self._unwrap_connection(conn)
|
||||
@@ -201,35 +210,46 @@ class BaseConnectionPool(Pool):
|
||||
if self._is_expired(now, now, created_at):
|
||||
self._safe_close(conn, quiet=False)
|
||||
conn = None
|
||||
else:
|
||||
# rollback any uncommitted changes, so that the next client
|
||||
# has a clean slate. This also pokes the connection to see if
|
||||
# it's dead or None
|
||||
elif cleanup is not None:
|
||||
if cleanup is _MISSING:
|
||||
cleanup = self.cleanup
|
||||
# by default, call rollback in case the connection is in the middle
|
||||
# of a transaction. However, rollback has performance implications
|
||||
# so optionally do nothing or call something else like ping
|
||||
try:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except:
|
||||
cleanup(conn)
|
||||
except Exception as e:
|
||||
# we don't care what the exception was, we just know the
|
||||
# connection is dead
|
||||
print("WARNING: connection.rollback raised: %s" % (sys.exc_info()[1]))
|
||||
print("WARNING: cleanup %s raised: %s" % (cleanup, e))
|
||||
conn = None
|
||||
except:
|
||||
conn = None
|
||||
raise
|
||||
|
||||
if conn is not None:
|
||||
super(BaseConnectionPool, self).put( (now, created_at, conn) )
|
||||
super(BaseConnectionPool, self).put((now, created_at, conn))
|
||||
else:
|
||||
# wake up any waiters with a flag value that indicates
|
||||
# they need to manufacture a connection
|
||||
if self.waiting() > 0:
|
||||
super(BaseConnectionPool, self).put(None)
|
||||
else:
|
||||
# no waiters -- just change the size
|
||||
self.current_size -= 1
|
||||
if self.waiting() > 0:
|
||||
super(BaseConnectionPool, self).put(None)
|
||||
else:
|
||||
# no waiters -- just change the size
|
||||
self.current_size -= 1
|
||||
self._schedule_expiration()
|
||||
|
||||
@contextmanager
|
||||
def item(self, cleanup=_MISSING):
|
||||
conn = self.get()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
self.put(conn, cleanup=cleanup)
|
||||
|
||||
def clear(self):
|
||||
""" Close all connections that this pool still holds a reference to,
|
||||
"""Close all connections that this pool still holds a reference to,
|
||||
and removes all references to them.
|
||||
"""
|
||||
if self._expiration_timer:
|
||||
@@ -250,8 +270,8 @@ class TpooledConnectionPool(BaseConnectionPool):
|
||||
"""
|
||||
def create(self):
|
||||
now = time.time()
|
||||
return now, now, self.connect(self._db_module,
|
||||
self.connect_timeout, *self._args, **self._kwargs)
|
||||
return now, now, self.connect(
|
||||
self._db_module, self.connect_timeout, *self._args, **self._kwargs)
|
||||
|
||||
@classmethod
|
||||
def connect(cls, db_module, connect_timeout, *args, **kw):
|
||||
@@ -269,8 +289,8 @@ class RawConnectionPool(BaseConnectionPool):
|
||||
"""
|
||||
def create(self):
|
||||
now = time.time()
|
||||
return now, now, self.connect(self._db_module,
|
||||
self.connect_timeout, *self._args, **self._kwargs)
|
||||
return now, now, self.connect(
|
||||
self._db_module, self.connect_timeout, *self._args, **self._kwargs)
|
||||
|
||||
@classmethod
|
||||
def connect(cls, db_module, connect_timeout, *args, **kw):
|
||||
@@ -288,20 +308,27 @@ ConnectionPool = TpooledConnectionPool
|
||||
class GenericConnectionWrapper(object):
|
||||
def __init__(self, baseconn):
|
||||
self._base = baseconn
|
||||
|
||||
# Proxy all method calls to self._base
|
||||
# FIXME: remove repetition; options to consider:
|
||||
# * for name in (...):
|
||||
# setattr(class, name, lambda self, *a, **kw: getattr(self._base, name)(*a, **kw))
|
||||
# * def __getattr__(self, name): if name in (...): return getattr(self._base, name)
|
||||
# * other?
|
||||
def __enter__(self): return self._base.__enter__()
|
||||
def __exit__(self, exc, value, tb): return self._base.__exit__(exc, value, tb)
|
||||
def __repr__(self): return self._base.__repr__()
|
||||
def affected_rows(self): return self._base.affected_rows()
|
||||
def autocommit(self,*args, **kwargs): return self._base.autocommit(*args, **kwargs)
|
||||
def autocommit(self, *args, **kwargs): return self._base.autocommit(*args, **kwargs)
|
||||
def begin(self): return self._base.begin()
|
||||
def change_user(self,*args, **kwargs): return self._base.change_user(*args, **kwargs)
|
||||
def character_set_name(self,*args, **kwargs): return self._base.character_set_name(*args, **kwargs)
|
||||
def close(self,*args, **kwargs): return self._base.close(*args, **kwargs)
|
||||
def commit(self,*args, **kwargs): return self._base.commit(*args, **kwargs)
|
||||
def change_user(self, *args, **kwargs): return self._base.change_user(*args, **kwargs)
|
||||
def character_set_name(self, *args, **kwargs): return self._base.character_set_name(*args, **kwargs)
|
||||
def close(self, *args, **kwargs): return self._base.close(*args, **kwargs)
|
||||
def commit(self, *args, **kwargs): return self._base.commit(*args, **kwargs)
|
||||
def cursor(self, *args, **kwargs): return self._base.cursor(*args, **kwargs)
|
||||
def dump_debug_info(self,*args, **kwargs): return self._base.dump_debug_info(*args, **kwargs)
|
||||
def errno(self,*args, **kwargs): return self._base.errno(*args, **kwargs)
|
||||
def error(self,*args, **kwargs): return self._base.error(*args, **kwargs)
|
||||
def dump_debug_info(self, *args, **kwargs): return self._base.dump_debug_info(*args, **kwargs)
|
||||
def errno(self, *args, **kwargs): return self._base.errno(*args, **kwargs)
|
||||
def error(self, *args, **kwargs): return self._base.error(*args, **kwargs)
|
||||
def errorhandler(self, *args, **kwargs): return self._base.errorhandler(*args, **kwargs)
|
||||
def insert_id(self, *args, **kwargs): return self._base.insert_id(*args, **kwargs)
|
||||
def literal(self, *args, **kwargs): return self._base.literal(*args, **kwargs)
|
||||
@@ -309,23 +336,23 @@ class GenericConnectionWrapper(object):
|
||||
def set_sql_mode(self, *args, **kwargs): return self._base.set_sql_mode(*args, **kwargs)
|
||||
def show_warnings(self): return self._base.show_warnings()
|
||||
def warning_count(self): return self._base.warning_count()
|
||||
def ping(self,*args, **kwargs): return self._base.ping(*args, **kwargs)
|
||||
def query(self,*args, **kwargs): return self._base.query(*args, **kwargs)
|
||||
def rollback(self,*args, **kwargs): return self._base.rollback(*args, **kwargs)
|
||||
def select_db(self,*args, **kwargs): return self._base.select_db(*args, **kwargs)
|
||||
def set_server_option(self,*args, **kwargs): return self._base.set_server_option(*args, **kwargs)
|
||||
def server_capabilities(self,*args, **kwargs): return self._base.server_capabilities(*args, **kwargs)
|
||||
def shutdown(self,*args, **kwargs): return self._base.shutdown(*args, **kwargs)
|
||||
def sqlstate(self,*args, **kwargs): return self._base.sqlstate(*args, **kwargs)
|
||||
def ping(self, *args, **kwargs): return self._base.ping(*args, **kwargs)
|
||||
def query(self, *args, **kwargs): return self._base.query(*args, **kwargs)
|
||||
def rollback(self, *args, **kwargs): return self._base.rollback(*args, **kwargs)
|
||||
def select_db(self, *args, **kwargs): return self._base.select_db(*args, **kwargs)
|
||||
def set_server_option(self, *args, **kwargs): return self._base.set_server_option(*args, **kwargs)
|
||||
def server_capabilities(self, *args, **kwargs): return self._base.server_capabilities(*args, **kwargs)
|
||||
def shutdown(self, *args, **kwargs): return self._base.shutdown(*args, **kwargs)
|
||||
def sqlstate(self, *args, **kwargs): return self._base.sqlstate(*args, **kwargs)
|
||||
def stat(self, *args, **kwargs): return self._base.stat(*args, **kwargs)
|
||||
def store_result(self,*args, **kwargs): return self._base.store_result(*args, **kwargs)
|
||||
def string_literal(self,*args, **kwargs): return self._base.string_literal(*args, **kwargs)
|
||||
def thread_id(self,*args, **kwargs): return self._base.thread_id(*args, **kwargs)
|
||||
def use_result(self,*args, **kwargs): return self._base.use_result(*args, **kwargs)
|
||||
def store_result(self, *args, **kwargs): return self._base.store_result(*args, **kwargs)
|
||||
def string_literal(self, *args, **kwargs): return self._base.string_literal(*args, **kwargs)
|
||||
def thread_id(self, *args, **kwargs): return self._base.thread_id(*args, **kwargs)
|
||||
def use_result(self, *args, **kwargs): return self._base.use_result(*args, **kwargs)
|
||||
|
||||
|
||||
class PooledConnectionWrapper(GenericConnectionWrapper):
|
||||
""" A connection wrapper where:
|
||||
"""A connection wrapper where:
|
||||
- the close method returns the connection to the pool instead of closing it directly
|
||||
- ``bool(conn)`` returns a reasonable value
|
||||
- returns itself to the pool if it gets garbage collected
|
||||
@@ -347,7 +374,7 @@ class PooledConnectionWrapper(GenericConnectionWrapper):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
""" Return the connection to the pool, and remove the
|
||||
"""Return the connection to the pool, and remove the
|
||||
reference to it so that you can't use it again through this
|
||||
wrapper object.
|
||||
"""
|
||||
@@ -362,17 +389,18 @@ class PooledConnectionWrapper(GenericConnectionWrapper):
|
||||
|
||||
|
||||
class DatabaseConnector(object):
|
||||
"""\
|
||||
"""
|
||||
This is an object which will maintain a collection of database
|
||||
connection pools on a per-host basis."""
|
||||
connection pools on a per-host basis.
|
||||
"""
|
||||
def __init__(self, module, credentials,
|
||||
conn_pool=None, *args, **kwargs):
|
||||
"""\
|
||||
constructor
|
||||
"""constructor
|
||||
*module*
|
||||
Database module to use.
|
||||
*credentials*
|
||||
Mapping of hostname to connect arguments (e.g. username and password)"""
|
||||
Mapping of hostname to connect arguments (e.g. username and password)
|
||||
"""
|
||||
assert(module)
|
||||
self._conn_pool_class = conn_pool
|
||||
if self._conn_pool_class is None:
|
||||
@@ -390,15 +418,16 @@ connection pools on a per-host basis."""
|
||||
return self._credentials.get('default', None)
|
||||
|
||||
def get(self, host, dbname):
|
||||
""" Returns a ConnectionPool to the target host and schema. """
|
||||
"""Returns a ConnectionPool to the target host and schema.
|
||||
"""
|
||||
key = (host, dbname)
|
||||
if key not in self._databases:
|
||||
new_kwargs = self._kwargs.copy()
|
||||
new_kwargs['db'] = dbname
|
||||
new_kwargs['host'] = host
|
||||
new_kwargs.update(self.credentials_for(host))
|
||||
dbpool = self._conn_pool_class(self._module,
|
||||
*self._args, **new_kwargs)
|
||||
dbpool = self._conn_pool_class(
|
||||
self._module, *self._args, **new_kwargs)
|
||||
self._databases[key] = dbpool
|
||||
|
||||
return self._databases[key]
|
||||
|
@@ -7,7 +7,7 @@ import os
|
||||
import traceback
|
||||
from unittest import TestCase, main
|
||||
|
||||
from tests import skipped, skip_unless, skip_with_pyevent, get_database_auth
|
||||
from tests import mock, skipped, skip_unless, skip_with_pyevent, get_database_auth
|
||||
from eventlet import event
|
||||
from eventlet import db_pool
|
||||
from eventlet.support import six
|
||||
@@ -16,6 +16,7 @@ import eventlet
|
||||
|
||||
class DBTester(object):
|
||||
__test__ = False # so that nose doesn't try to execute this directly
|
||||
|
||||
def setUp(self):
|
||||
self.create_db()
|
||||
self.connection = None
|
||||
@@ -58,6 +59,7 @@ class Mock(object):
|
||||
|
||||
class DBConnectionPool(DBTester):
|
||||
__test__ = False # so that nose doesn't try to execute this directly
|
||||
|
||||
def setUp(self):
|
||||
super(DBConnectionPool, self).setUp()
|
||||
self.pool = self.create_pool()
|
||||
@@ -148,6 +150,7 @@ class DBConnectionPool(DBTester):
|
||||
results = []
|
||||
SHORT_QUERY = "select * from test_table"
|
||||
evt = event.Event()
|
||||
|
||||
def a_query():
|
||||
self.assert_cursor_works(curs)
|
||||
curs.execute(SHORT_QUERY)
|
||||
@@ -228,12 +231,14 @@ class DBConnectionPool(DBTester):
|
||||
SHORT_QUERY = "select * from test_table where row_id <= 20"
|
||||
|
||||
evt = event.Event()
|
||||
|
||||
def long_running_query():
|
||||
self.assert_cursor_works(curs)
|
||||
curs.execute(LONG_QUERY)
|
||||
results.append(1)
|
||||
evt.send()
|
||||
evt2 = event.Event()
|
||||
|
||||
def short_running_query():
|
||||
self.assert_cursor_works(curs2)
|
||||
curs2.execute(SHORT_QUERY)
|
||||
@@ -284,12 +289,14 @@ class DBConnectionPool(DBTester):
|
||||
|
||||
# now we're really going for 100% coverage
|
||||
x = Mock()
|
||||
|
||||
def fail():
|
||||
raise KeyboardInterrupt()
|
||||
x.close = fail
|
||||
self.assertRaises(KeyboardInterrupt, self.pool._safe_close, x)
|
||||
|
||||
x = Mock()
|
||||
|
||||
def fail2():
|
||||
raise RuntimeError("if this line has been printed, the test succeeded")
|
||||
x.close = fail2
|
||||
@@ -331,7 +338,7 @@ class DBConnectionPool(DBTester):
|
||||
self.connection = self.pool.get()
|
||||
self.connection.close()
|
||||
self.assertEqual(len(self.pool.free_items), 1)
|
||||
eventlet.sleep(0.03) # long enough to trigger idle timeout for real
|
||||
eventlet.sleep(0.03) # long enough to trigger idle timeout for real
|
||||
self.assertEqual(len(self.pool.free_items), 0)
|
||||
|
||||
@skipped
|
||||
@@ -365,7 +372,7 @@ class DBConnectionPool(DBTester):
|
||||
self.connection = self.pool.get()
|
||||
self.connection.close()
|
||||
self.assertEqual(len(self.pool.free_items), 1)
|
||||
eventlet.sleep(0.05) # long enough to trigger age timeout
|
||||
eventlet.sleep(0.05) # long enough to trigger age timeout
|
||||
self.assertEqual(len(self.pool.free_items), 0)
|
||||
|
||||
@skipped
|
||||
@@ -380,7 +387,7 @@ class DBConnectionPool(DBTester):
|
||||
self.assertEqual(len(self.pool.free_items), 1)
|
||||
eventlet.sleep(0) # not long enough to trigger the age timeout
|
||||
self.assertEqual(len(self.pool.free_items), 1)
|
||||
eventlet.sleep(0.2) # long enough to trigger age timeout
|
||||
eventlet.sleep(0.2) # long enough to trigger age timeout
|
||||
self.assertEqual(len(self.pool.free_items), 0)
|
||||
conn2.close() # should not be added to the free items
|
||||
self.assertEqual(len(self.pool.free_items), 0)
|
||||
@@ -397,12 +404,13 @@ class DBConnectionPool(DBTester):
|
||||
self.assertEqual(self.pool.free(), 0)
|
||||
self.assertEqual(self.pool.waiting(), 0)
|
||||
e = event.Event()
|
||||
|
||||
def retrieve(pool, ev):
|
||||
c = pool.get()
|
||||
ev.send(c)
|
||||
eventlet.spawn(retrieve, self.pool, e)
|
||||
eventlet.sleep(0) # these two sleeps should advance the retrieve
|
||||
eventlet.sleep(0) # coroutine until it's waiting in get()
|
||||
eventlet.sleep(0) # these two sleeps should advance the retrieve
|
||||
eventlet.sleep(0) # coroutine until it's waiting in get()
|
||||
self.assertEqual(self.pool.free(), 0)
|
||||
self.assertEqual(self.pool.waiting(), 1)
|
||||
self.pool.put(self.connection)
|
||||
@@ -420,6 +428,7 @@ class DBConnectionPool(DBTester):
|
||||
iterations = 20000
|
||||
c = self.connection.cursor()
|
||||
self.connection.commit()
|
||||
|
||||
def bench(c):
|
||||
for i in six.moves.range(iterations):
|
||||
c.execute('select 1')
|
||||
@@ -459,17 +468,18 @@ class RaisingDBModule(object):
|
||||
|
||||
class TpoolConnectionPool(DBConnectionPool):
|
||||
__test__ = False # so that nose doesn't try to execute this directly
|
||||
|
||||
def create_pool(self, min_size=0, max_size=1, max_idle=10, max_age=10,
|
||||
connect_timeout=0.5, module=None):
|
||||
if module is None:
|
||||
module = self._dbmodule
|
||||
return db_pool.TpooledConnectionPool(module,
|
||||
return db_pool.TpooledConnectionPool(
|
||||
module,
|
||||
min_size=min_size, max_size=max_size,
|
||||
max_idle=max_idle, max_age=max_age,
|
||||
connect_timeout = connect_timeout,
|
||||
connect_timeout=connect_timeout,
|
||||
**self._auth)
|
||||
|
||||
|
||||
@skip_with_pyevent
|
||||
def setUp(self):
|
||||
super(TpoolConnectionPool, self).setUp()
|
||||
@@ -480,14 +490,15 @@ class TpoolConnectionPool(DBConnectionPool):
|
||||
tpool.killall()
|
||||
|
||||
|
||||
|
||||
class RawConnectionPool(DBConnectionPool):
|
||||
__test__ = False # so that nose doesn't try to execute this directly
|
||||
|
||||
def create_pool(self, min_size=0, max_size=1, max_idle=10, max_age=10,
|
||||
connect_timeout=0.5, module=None):
|
||||
if module is None:
|
||||
module = self._dbmodule
|
||||
return db_pool.RawConnectionPool(module,
|
||||
return db_pool.RawConnectionPool(
|
||||
module,
|
||||
min_size=min_size, max_size=max_size,
|
||||
max_idle=max_idle, max_age=max_age,
|
||||
connect_timeout=connect_timeout,
|
||||
@@ -497,12 +508,52 @@ class RawConnectionPool(DBConnectionPool):
|
||||
class TestRawConnectionPool(TestCase):
|
||||
def test_issue_125(self):
|
||||
# pool = self.create_pool(min_size=3, max_size=5)
|
||||
pool = db_pool.RawConnectionPool(DummyDBModule(),
|
||||
pool = db_pool.RawConnectionPool(
|
||||
DummyDBModule(),
|
||||
dsn="dbname=test user=jessica port=5433",
|
||||
min_size=3, max_size=5)
|
||||
conn = pool.get()
|
||||
pool.put(conn)
|
||||
|
||||
def test_custom_cleanup_ok(self):
|
||||
cleanup_mock = mock.Mock()
|
||||
pool = db_pool.RawConnectionPool(DummyDBModule(), cleanup=cleanup_mock)
|
||||
conn = pool.get()
|
||||
pool.put(conn)
|
||||
assert cleanup_mock.call_count == 1
|
||||
|
||||
with pool.item() as conn:
|
||||
pass
|
||||
assert cleanup_mock.call_count == 2
|
||||
|
||||
def test_custom_cleanup_arg_error(self):
|
||||
cleanup_mock = mock.Mock(side_effect=NotImplementedError)
|
||||
pool = db_pool.RawConnectionPool(DummyDBModule())
|
||||
conn = pool.get()
|
||||
pool.put(conn, cleanup=cleanup_mock)
|
||||
assert cleanup_mock.call_count == 1
|
||||
|
||||
with pool.item(cleanup=cleanup_mock):
|
||||
pass
|
||||
assert cleanup_mock.call_count == 2
|
||||
|
||||
def test_custom_cleanup_fatal(self):
|
||||
state = [0]
|
||||
|
||||
def cleanup(conn):
|
||||
state[0] += 1
|
||||
raise KeyboardInterrupt
|
||||
|
||||
pool = db_pool.RawConnectionPool(DummyDBModule(), cleanup=cleanup)
|
||||
conn = pool.get()
|
||||
try:
|
||||
pool.put(conn)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
else:
|
||||
assert False, 'Expected KeyboardInterrupt'
|
||||
assert state[0] == 1
|
||||
|
||||
|
||||
get_auth = get_database_auth
|
||||
|
||||
|
Reference in New Issue
Block a user