db_pool: customizable connection cleanup function; Thanks to Avery Fay

https://github.com/eventlet/eventlet/pull/64

Also:
- PEP8
- except Exception
- .put() must not catch SystemExit
This commit is contained in:
Sergey Shepelev
2014-06-13 17:15:47 +04:00
parent bd7a14305c
commit d2ec34fb14
2 changed files with 165 additions and 85 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import print_function
from collections import deque
from contextlib import contextmanager
import sys
import time
@@ -11,16 +12,24 @@ from eventlet.hubs.timer import Timer
from eventlet.greenthread import GreenThread
_MISSING = object()
class ConnectTimeout(Exception):
pass
def cleanup_rollback(conn):
conn.rollback()
class BaseConnectionPool(Pool):
def __init__(self, db_module,
min_size = 0, max_size = 4,
max_idle = 10, max_age = 30,
connect_timeout = 5,
*args, **kwargs):
min_size=0, max_size=4,
max_idle=10, max_age=30,
connect_timeout=5,
cleanup=cleanup_rollback,
*args, **kwargs):
"""
Constructs a pool with at least *min_size* connections and at most
*max_size* connections. Uses *db_module* to construct new connections.
@@ -49,19 +58,20 @@ class BaseConnectionPool(Pool):
self.max_age = max_age
self.connect_timeout = connect_timeout
self._expiration_timer = None
self.cleanup = cleanup
super(BaseConnectionPool, self).__init__(min_size=min_size,
max_size=max_size,
order_as_stack=True)
def _schedule_expiration(self):
""" Sets up a timer that will call _expire_old_connections when the
"""Sets up a timer that will call _expire_old_connections when the
oldest connection currently in the free pool is ready to expire. This
is the earliest possible time that a connection could expire, thus, the
timer will be running as infrequently as possible without missing a
possible expiration.
If this function is called when a timer is already scheduled, it does
nothing.
nothing.
If max_age or max_idle is 0, _schedule_expiration likewise does nothing.
"""
@@ -70,8 +80,8 @@ class BaseConnectionPool(Pool):
# on put
return
if ( self._expiration_timer is not None
and not getattr(self._expiration_timer, 'called', False)):
if (self._expiration_timer is not None
and not getattr(self._expiration_timer, 'called', False)):
# the next timer is already scheduled
return
@@ -97,7 +107,7 @@ class BaseConnectionPool(Pool):
self._expiration_timer.schedule()
def _expire_old_connections(self, now):
""" Iterates through the open connections contained in the pool, closing
"""Iterates through the open connections contained in the pool, closing
ones that have remained idle for longer than max_idle seconds, or have
been in existence for longer than max_age seconds.
@@ -124,16 +134,16 @@ class BaseConnectionPool(Pool):
self._safe_close(conn, quiet=True)
def _is_expired(self, now, last_used, created_at):
""" Returns true and closes the connection if it's expired."""
if ( self.max_idle <= 0
or self.max_age <= 0
or now - last_used > self.max_idle
or now - created_at > self.max_age ):
"""Returns true and closes the connection if it's expired.
"""
if (self.max_idle <= 0 or self.max_age <= 0
or now - last_used > self.max_idle
or now - created_at > self.max_age):
return True
return False
def _unwrap_connection(self, conn):
""" If the connection was wrapped by a subclass of
"""If the connection was wrapped by a subclass of
BaseConnectionWrapper and is still functional (as determined
by the __nonzero__, or __bool__ in python3, method), returns
the unwrapped connection. If anything goes wrong with this
@@ -150,16 +160,15 @@ class BaseConnectionPool(Pool):
pass
return base
def _safe_close(self, conn, quiet = False):
""" Closes the (already unwrapped) connection, squelching any
exceptions."""
def _safe_close(self, conn, quiet=False):
"""Closes the (already unwrapped) connection, squelching any
exceptions.
"""
try:
conn.close()
except (KeyboardInterrupt, SystemExit):
raise
except AttributeError:
pass # conn is None, or junk
except:
pass # conn is None, or junk
except Exception:
if not quiet:
print("Connection.close raised: %s" % (sys.exc_info()[1]))
@@ -193,7 +202,7 @@ class BaseConnectionPool(Pool):
wrapped._db_pool_created_at = created_at
return wrapped
def put(self, conn):
def put(self, conn, cleanup=_MISSING):
created_at = getattr(conn, '_db_pool_created_at', 0)
now = time.time()
conn = self._unwrap_connection(conn)
@@ -201,35 +210,46 @@ class BaseConnectionPool(Pool):
if self._is_expired(now, now, created_at):
self._safe_close(conn, quiet=False)
conn = None
else:
# rollback any uncommitted changes, so that the next client
# has a clean slate. This also pokes the connection to see if
# it's dead or None
elif cleanup is not None:
if cleanup is _MISSING:
cleanup = self.cleanup
# by default, call rollback in case the connection is in the middle
# of a transaction. However, rollback has performance implications
# so optionally do nothing or call something else like ping
try:
if conn:
conn.rollback()
except KeyboardInterrupt:
raise
except:
cleanup(conn)
except Exception as e:
# we don't care what the exception was, we just know the
# connection is dead
print("WARNING: connection.rollback raised: %s" % (sys.exc_info()[1]))
print("WARNING: cleanup %s raised: %s" % (cleanup, e))
conn = None
except:
conn = None
raise
if conn is not None:
super(BaseConnectionPool, self).put( (now, created_at, conn) )
super(BaseConnectionPool, self).put((now, created_at, conn))
else:
# wake up any waiters with a flag value that indicates
# they need to manufacture a connection
if self.waiting() > 0:
super(BaseConnectionPool, self).put(None)
else:
# no waiters -- just change the size
self.current_size -= 1
if self.waiting() > 0:
super(BaseConnectionPool, self).put(None)
else:
# no waiters -- just change the size
self.current_size -= 1
self._schedule_expiration()
@contextmanager
def item(self, cleanup=_MISSING):
conn = self.get()
try:
yield conn
finally:
self.put(conn, cleanup=cleanup)
def clear(self):
""" Close all connections that this pool still holds a reference to,
"""Close all connections that this pool still holds a reference to,
and removes all references to them.
"""
if self._expiration_timer:
@@ -250,8 +270,8 @@ class TpooledConnectionPool(BaseConnectionPool):
"""
def create(self):
now = time.time()
return now, now, self.connect(self._db_module,
self.connect_timeout, *self._args, **self._kwargs)
return now, now, self.connect(
self._db_module, self.connect_timeout, *self._args, **self._kwargs)
@classmethod
def connect(cls, db_module, connect_timeout, *args, **kw):
@@ -269,8 +289,8 @@ class RawConnectionPool(BaseConnectionPool):
"""
def create(self):
now = time.time()
return now, now, self.connect(self._db_module,
self.connect_timeout, *self._args, **self._kwargs)
return now, now, self.connect(
self._db_module, self.connect_timeout, *self._args, **self._kwargs)
@classmethod
def connect(cls, db_module, connect_timeout, *args, **kw):
@@ -288,20 +308,27 @@ ConnectionPool = TpooledConnectionPool
class GenericConnectionWrapper(object):
def __init__(self, baseconn):
self._base = baseconn
# Proxy all method calls to self._base
# FIXME: remove repetition; options to consider:
# * for name in (...):
# setattr(class, name, lambda self, *a, **kw: getattr(self._base, name)(*a, **kw))
# * def __getattr__(self, name): if name in (...): return getattr(self._base, name)
# * other?
def __enter__(self): return self._base.__enter__()
def __exit__(self, exc, value, tb): return self._base.__exit__(exc, value, tb)
def __repr__(self): return self._base.__repr__()
def affected_rows(self): return self._base.affected_rows()
def autocommit(self,*args, **kwargs): return self._base.autocommit(*args, **kwargs)
def autocommit(self, *args, **kwargs): return self._base.autocommit(*args, **kwargs)
def begin(self): return self._base.begin()
def change_user(self,*args, **kwargs): return self._base.change_user(*args, **kwargs)
def character_set_name(self,*args, **kwargs): return self._base.character_set_name(*args, **kwargs)
def close(self,*args, **kwargs): return self._base.close(*args, **kwargs)
def commit(self,*args, **kwargs): return self._base.commit(*args, **kwargs)
def change_user(self, *args, **kwargs): return self._base.change_user(*args, **kwargs)
def character_set_name(self, *args, **kwargs): return self._base.character_set_name(*args, **kwargs)
def close(self, *args, **kwargs): return self._base.close(*args, **kwargs)
def commit(self, *args, **kwargs): return self._base.commit(*args, **kwargs)
def cursor(self, *args, **kwargs): return self._base.cursor(*args, **kwargs)
def dump_debug_info(self,*args, **kwargs): return self._base.dump_debug_info(*args, **kwargs)
def errno(self,*args, **kwargs): return self._base.errno(*args, **kwargs)
def error(self,*args, **kwargs): return self._base.error(*args, **kwargs)
def dump_debug_info(self, *args, **kwargs): return self._base.dump_debug_info(*args, **kwargs)
def errno(self, *args, **kwargs): return self._base.errno(*args, **kwargs)
def error(self, *args, **kwargs): return self._base.error(*args, **kwargs)
def errorhandler(self, *args, **kwargs): return self._base.errorhandler(*args, **kwargs)
def insert_id(self, *args, **kwargs): return self._base.insert_id(*args, **kwargs)
def literal(self, *args, **kwargs): return self._base.literal(*args, **kwargs)
@@ -309,23 +336,23 @@ class GenericConnectionWrapper(object):
def set_sql_mode(self, *args, **kwargs): return self._base.set_sql_mode(*args, **kwargs)
def show_warnings(self): return self._base.show_warnings()
def warning_count(self): return self._base.warning_count()
def ping(self,*args, **kwargs): return self._base.ping(*args, **kwargs)
def query(self,*args, **kwargs): return self._base.query(*args, **kwargs)
def rollback(self,*args, **kwargs): return self._base.rollback(*args, **kwargs)
def select_db(self,*args, **kwargs): return self._base.select_db(*args, **kwargs)
def set_server_option(self,*args, **kwargs): return self._base.set_server_option(*args, **kwargs)
def server_capabilities(self,*args, **kwargs): return self._base.server_capabilities(*args, **kwargs)
def shutdown(self,*args, **kwargs): return self._base.shutdown(*args, **kwargs)
def sqlstate(self,*args, **kwargs): return self._base.sqlstate(*args, **kwargs)
def ping(self, *args, **kwargs): return self._base.ping(*args, **kwargs)
def query(self, *args, **kwargs): return self._base.query(*args, **kwargs)
def rollback(self, *args, **kwargs): return self._base.rollback(*args, **kwargs)
def select_db(self, *args, **kwargs): return self._base.select_db(*args, **kwargs)
def set_server_option(self, *args, **kwargs): return self._base.set_server_option(*args, **kwargs)
def server_capabilities(self, *args, **kwargs): return self._base.server_capabilities(*args, **kwargs)
def shutdown(self, *args, **kwargs): return self._base.shutdown(*args, **kwargs)
def sqlstate(self, *args, **kwargs): return self._base.sqlstate(*args, **kwargs)
def stat(self, *args, **kwargs): return self._base.stat(*args, **kwargs)
def store_result(self,*args, **kwargs): return self._base.store_result(*args, **kwargs)
def string_literal(self,*args, **kwargs): return self._base.string_literal(*args, **kwargs)
def thread_id(self,*args, **kwargs): return self._base.thread_id(*args, **kwargs)
def use_result(self,*args, **kwargs): return self._base.use_result(*args, **kwargs)
def store_result(self, *args, **kwargs): return self._base.store_result(*args, **kwargs)
def string_literal(self, *args, **kwargs): return self._base.string_literal(*args, **kwargs)
def thread_id(self, *args, **kwargs): return self._base.thread_id(*args, **kwargs)
def use_result(self, *args, **kwargs): return self._base.use_result(*args, **kwargs)
class PooledConnectionWrapper(GenericConnectionWrapper):
""" A connection wrapper where:
"""A connection wrapper where:
- the close method returns the connection to the pool instead of closing it directly
- ``bool(conn)`` returns a reasonable value
- returns itself to the pool if it gets garbage collected
@@ -347,7 +374,7 @@ class PooledConnectionWrapper(GenericConnectionWrapper):
pass
def close(self):
""" Return the connection to the pool, and remove the
"""Return the connection to the pool, and remove the
reference to it so that you can't use it again through this
wrapper object.
"""
@@ -362,17 +389,18 @@ class PooledConnectionWrapper(GenericConnectionWrapper):
class DatabaseConnector(object):
"""\
"""
This is an object which will maintain a collection of database
connection pools on a per-host basis."""
connection pools on a per-host basis.
"""
def __init__(self, module, credentials,
conn_pool=None, *args, **kwargs):
"""\
constructor
"""constructor
*module*
Database module to use.
*credentials*
Mapping of hostname to connect arguments (e.g. username and password)"""
Mapping of hostname to connect arguments (e.g. username and password)
"""
assert(module)
self._conn_pool_class = conn_pool
if self._conn_pool_class is None:
@@ -390,15 +418,16 @@ connection pools on a per-host basis."""
return self._credentials.get('default', None)
def get(self, host, dbname):
""" Returns a ConnectionPool to the target host and schema. """
"""Returns a ConnectionPool to the target host and schema.
"""
key = (host, dbname)
if key not in self._databases:
new_kwargs = self._kwargs.copy()
new_kwargs['db'] = dbname
new_kwargs['host'] = host
new_kwargs.update(self.credentials_for(host))
dbpool = self._conn_pool_class(self._module,
*self._args, **new_kwargs)
dbpool = self._conn_pool_class(
self._module, *self._args, **new_kwargs)
self._databases[key] = dbpool
return self._databases[key]

View File

@@ -7,7 +7,7 @@ import os
import traceback
from unittest import TestCase, main
from tests import skipped, skip_unless, skip_with_pyevent, get_database_auth
from tests import mock, skipped, skip_unless, skip_with_pyevent, get_database_auth
from eventlet import event
from eventlet import db_pool
from eventlet.support import six
@@ -16,6 +16,7 @@ import eventlet
class DBTester(object):
__test__ = False # so that nose doesn't try to execute this directly
def setUp(self):
self.create_db()
self.connection = None
@@ -58,6 +59,7 @@ class Mock(object):
class DBConnectionPool(DBTester):
__test__ = False # so that nose doesn't try to execute this directly
def setUp(self):
super(DBConnectionPool, self).setUp()
self.pool = self.create_pool()
@@ -148,6 +150,7 @@ class DBConnectionPool(DBTester):
results = []
SHORT_QUERY = "select * from test_table"
evt = event.Event()
def a_query():
self.assert_cursor_works(curs)
curs.execute(SHORT_QUERY)
@@ -228,12 +231,14 @@ class DBConnectionPool(DBTester):
SHORT_QUERY = "select * from test_table where row_id <= 20"
evt = event.Event()
def long_running_query():
self.assert_cursor_works(curs)
curs.execute(LONG_QUERY)
results.append(1)
evt.send()
evt2 = event.Event()
def short_running_query():
self.assert_cursor_works(curs2)
curs2.execute(SHORT_QUERY)
@@ -284,12 +289,14 @@ class DBConnectionPool(DBTester):
# now we're really going for 100% coverage
x = Mock()
def fail():
raise KeyboardInterrupt()
x.close = fail
self.assertRaises(KeyboardInterrupt, self.pool._safe_close, x)
x = Mock()
def fail2():
raise RuntimeError("if this line has been printed, the test succeeded")
x.close = fail2
@@ -331,7 +338,7 @@ class DBConnectionPool(DBTester):
self.connection = self.pool.get()
self.connection.close()
self.assertEqual(len(self.pool.free_items), 1)
eventlet.sleep(0.03) # long enough to trigger idle timeout for real
eventlet.sleep(0.03) # long enough to trigger idle timeout for real
self.assertEqual(len(self.pool.free_items), 0)
@skipped
@@ -365,7 +372,7 @@ class DBConnectionPool(DBTester):
self.connection = self.pool.get()
self.connection.close()
self.assertEqual(len(self.pool.free_items), 1)
eventlet.sleep(0.05) # long enough to trigger age timeout
eventlet.sleep(0.05) # long enough to trigger age timeout
self.assertEqual(len(self.pool.free_items), 0)
@skipped
@@ -380,7 +387,7 @@ class DBConnectionPool(DBTester):
self.assertEqual(len(self.pool.free_items), 1)
eventlet.sleep(0) # not long enough to trigger the age timeout
self.assertEqual(len(self.pool.free_items), 1)
eventlet.sleep(0.2) # long enough to trigger age timeout
eventlet.sleep(0.2) # long enough to trigger age timeout
self.assertEqual(len(self.pool.free_items), 0)
conn2.close() # should not be added to the free items
self.assertEqual(len(self.pool.free_items), 0)
@@ -397,12 +404,13 @@ class DBConnectionPool(DBTester):
self.assertEqual(self.pool.free(), 0)
self.assertEqual(self.pool.waiting(), 0)
e = event.Event()
def retrieve(pool, ev):
c = pool.get()
ev.send(c)
eventlet.spawn(retrieve, self.pool, e)
eventlet.sleep(0) # these two sleeps should advance the retrieve
eventlet.sleep(0) # coroutine until it's waiting in get()
eventlet.sleep(0) # these two sleeps should advance the retrieve
eventlet.sleep(0) # coroutine until it's waiting in get()
self.assertEqual(self.pool.free(), 0)
self.assertEqual(self.pool.waiting(), 1)
self.pool.put(self.connection)
@@ -420,6 +428,7 @@ class DBConnectionPool(DBTester):
iterations = 20000
c = self.connection.cursor()
self.connection.commit()
def bench(c):
for i in six.moves.range(iterations):
c.execute('select 1')
@@ -459,17 +468,18 @@ class RaisingDBModule(object):
class TpoolConnectionPool(DBConnectionPool):
__test__ = False # so that nose doesn't try to execute this directly
def create_pool(self, min_size=0, max_size=1, max_idle=10, max_age=10,
connect_timeout=0.5, module=None):
if module is None:
module = self._dbmodule
return db_pool.TpooledConnectionPool(module,
return db_pool.TpooledConnectionPool(
module,
min_size=min_size, max_size=max_size,
max_idle=max_idle, max_age=max_age,
connect_timeout = connect_timeout,
connect_timeout=connect_timeout,
**self._auth)
@skip_with_pyevent
def setUp(self):
super(TpoolConnectionPool, self).setUp()
@@ -480,14 +490,15 @@ class TpoolConnectionPool(DBConnectionPool):
tpool.killall()
class RawConnectionPool(DBConnectionPool):
__test__ = False # so that nose doesn't try to execute this directly
def create_pool(self, min_size=0, max_size=1, max_idle=10, max_age=10,
connect_timeout=0.5, module=None):
if module is None:
module = self._dbmodule
return db_pool.RawConnectionPool(module,
return db_pool.RawConnectionPool(
module,
min_size=min_size, max_size=max_size,
max_idle=max_idle, max_age=max_age,
connect_timeout=connect_timeout,
@@ -497,12 +508,52 @@ class RawConnectionPool(DBConnectionPool):
class TestRawConnectionPool(TestCase):
def test_issue_125(self):
# pool = self.create_pool(min_size=3, max_size=5)
pool = db_pool.RawConnectionPool(DummyDBModule(),
pool = db_pool.RawConnectionPool(
DummyDBModule(),
dsn="dbname=test user=jessica port=5433",
min_size=3, max_size=5)
conn = pool.get()
pool.put(conn)
def test_custom_cleanup_ok(self):
cleanup_mock = mock.Mock()
pool = db_pool.RawConnectionPool(DummyDBModule(), cleanup=cleanup_mock)
conn = pool.get()
pool.put(conn)
assert cleanup_mock.call_count == 1
with pool.item() as conn:
pass
assert cleanup_mock.call_count == 2
def test_custom_cleanup_arg_error(self):
cleanup_mock = mock.Mock(side_effect=NotImplementedError)
pool = db_pool.RawConnectionPool(DummyDBModule())
conn = pool.get()
pool.put(conn, cleanup=cleanup_mock)
assert cleanup_mock.call_count == 1
with pool.item(cleanup=cleanup_mock):
pass
assert cleanup_mock.call_count == 2
def test_custom_cleanup_fatal(self):
state = [0]
def cleanup(conn):
state[0] += 1
raise KeyboardInterrupt
pool = db_pool.RawConnectionPool(DummyDBModule(), cleanup=cleanup)
conn = pool.get()
try:
pool.put(conn)
except KeyboardInterrupt:
pass
else:
assert False, 'Expected KeyboardInterrupt'
assert state[0] == 1
get_auth = get_database_auth