Use SQLAlchemy native pre-ping

This functionality has been available upstream since SQLAlchemy 1.2 [1].
However, for oslo.db to use this feature while preserving its existing
behavior, we need at least SQLAlchemy 2.0.5, which provides complete
event support.  In particular, oslo.db adds several new "is disconnect"
conditions, including those specific to Galera.

The handle_error event handler is modified to support the "pre-ping"
calling form, in which an "engine" may not be present (only a dialect).
It additionally takes advantage of the new is_pre_ping attribute, which
tells the handler the correct way to set the disconnection status when
it is invoked from within the ping handler.

Change-Id: I50d862d3cbb126987a63209795352c6e801ed919
This commit is contained in:
Stephen Finucane 2023-03-22 11:15:16 +00:00
parent 1f003bcb0b
commit 64e50494f2
4 changed files with 267 additions and 111 deletions

View File

@ -18,6 +18,8 @@ from sqlalchemy import __version__
_vers = versionutils.convert_version_to_tuple(__version__)
sqla_2 = _vers >= (2, )
native_pre_ping_event_support = _vers >= (2, 0, 5)
def dialect_from_exception_context(ctx):
if sqla_2:

View File

@ -60,6 +60,12 @@ def _connect_ping_listener(connection, branch):
Ping the server at transaction begin and transparently reconnect
if a disconnect exception occurs.
This listener is used up until SQLAlchemy 2.0.5. At 2.0.5, we use the
``pool_pre_ping`` parameter instead of this event handler.
Note the current test suite in test_exc_filters still **tests** this
handler using all SQLAlchemy versions including 2.0.5 and greater.
"""
if branch:
return
@ -199,8 +205,11 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
_vet_url(url)
_native_pre_ping = compat.native_pre_ping_event_support
engine_args = {
'pool_recycle': connection_recycle_time,
'pool_pre_ping': _native_pre_ping,
'connect_args': {},
'logging_name': logging_name
}
@ -236,9 +245,10 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
# register alternate exception handler
exc_filters.register_engine(engine)
# register engine connect handler
if not _native_pre_ping:
# register engine connect handler.
event.listen(engine, "engine_connect", _connect_ping_listener)
event.listen(engine, "engine_connect", _connect_ping_listener)
# initial connect + test
# NOTE(viktors): the current implementation of _test_connection()

View File

@ -20,7 +20,7 @@ from sqlalchemy import event
from sqlalchemy import exc as sqla_exc
from oslo_db import exception
from oslo_db.sqlalchemy import compat
LOG = logging.getLogger(__name__)
@ -377,6 +377,7 @@ def _raise_operational_errors_directly_filter(operational_error,
def _is_db_connection_error(operational_error, match, engine_name,
is_disconnect):
"""Detect the exception as indicating a recoverable error on connect."""
raise exception.DBConnectionError(operational_error)
@ -423,13 +424,14 @@ def handler(context):
more specific exception class are attempted first.
"""
def _dialect_registries(engine):
if engine.dialect.name in _registry:
yield _registry[engine.dialect.name]
def _dialect_registries(dialect):
if dialect.name in _registry:
yield _registry[dialect.name]
if '*' in _registry:
yield _registry['*']
for per_dialect in _dialect_registries(context.engine):
dialect = compat.dialect_from_exception_context(context)
for per_dialect in _dialect_registries(dialect):
for exc in (
context.sqlalchemy_exception,
context.original_exception):
@ -443,7 +445,7 @@ def handler(context):
fn(
exc,
match,
context.engine.dialect.name,
dialect.name,
context.is_disconnect)
except exception.DBError as dbe:
if (
@ -460,6 +462,19 @@ def handler(context):
if isinstance(
dbe, exception.DBConnectionError):
context.is_disconnect = True
# new in 2.0.5
if (
hasattr(context, "is_pre_ping") and
context.is_pre_ping
):
# if this is a pre-ping, need to
# integrate with the built
# in pre-ping handler that doesnt know
# about DBConnectionError, just needs
# the updated status
return None
return dbe

View File

@ -1190,41 +1190,16 @@ class IntegrationTest(db_test_base._DbTestCase):
self.assertIn("no such function", str(matched))
class TestDBDisconnected(TestsExceptionFilter):
@contextlib.contextmanager
def _fixture(
self,
dialect_name, exception, num_disconnects, is_disconnect=True):
engine = self.engine
event.listen(engine, "engine_connect", engines._connect_ping_listener)
real_do_execute = engine.dialect.do_execute
counter = itertools.count(1)
def fake_do_execute(self, *arg, **kw):
if next(counter) > num_disconnects:
return real_do_execute(self, *arg, **kw)
else:
raise exception
with self._dbapi_fixture(dialect_name):
with test_utils.nested(
mock.patch.object(engine.dialect,
"do_execute",
fake_do_execute),
mock.patch.object(engine.dialect,
"is_disconnect",
mock.Mock(return_value=is_disconnect))
):
yield
class TestDBDisconnectedFixture(TestsExceptionFilter):
native_pre_ping = False
def _test_ping_listener_disconnected(
self, dialect_name, exc_obj, is_disconnect=True,
):
with self._fixture(dialect_name, exc_obj, 1, is_disconnect):
conn = self.engine.connect()
with self._fixture(
dialect_name, exc_obj, False, is_disconnect,
) as engine:
conn = engine.connect()
with conn.begin():
self.assertEqual(
1, conn.execute(sqla.select(1)).scalars().first(),
@ -1233,19 +1208,145 @@ class TestDBDisconnected(TestsExceptionFilter):
self.assertFalse(conn.invalidated)
self.assertTrue(conn.in_transaction())
with self._fixture(dialect_name, exc_obj, 2, is_disconnect):
with self._fixture(
dialect_name, exc_obj, True, is_disconnect,
) as engine:
self.assertRaises(
exception.DBConnectionError,
self.engine.connect
engine.connect
)
# test implicit execution
with self._fixture(dialect_name, exc_obj, 1):
with self.engine.connect() as conn:
with self._fixture(dialect_name, exc_obj, False) as engine:
with engine.connect() as conn:
self.assertEqual(
1, conn.execute(sqla.select(1)).scalars().first(),
)
@contextlib.contextmanager
def _fixture(
    self,
    dialect_name,
    exception,
    db_stays_down,
    is_disconnect=True,
):
    """Fixture for testing the ping listener.

    For SQLAlchemy 2.0, the mocking is placed more deeply in the
    stack within the DBAPI connection / cursor so that we can also
    effectively mock out the "pre ping" condition.

    The fixture builds its own engine (separate from ``self.engine``),
    yields it with the database "disconnected", and disposes of it on
    exit so that its pooled connections are not leaked between tests.

    :param dialect_name: dialect to use. "postgresql" or "mysql"

    :param exception: the exception to raise while the database is
     "down"; callers pass an exception instance

    :param db_stays_down: if True, the database will stay down after the
     first ping fails

    :param is_disconnect: whether or not the SQLAlchemy dialect should
     consider the exception object as a "disconnect error". Openstack's
     own exception handlers upgrade various DB exceptions to be
     "disconnect" scenarios that SQLAlchemy itself does not, such as
     some specific Galera error messages.

    The importance of an exception being a "disconnect error" means that
    SQLAlchemy knows it can discard the connection and then reconnect.
    If the error is not a "disconnection error", then it raises.

    :raises NotImplementedError: if ``dialect_name`` is neither
     "postgresql" nor "mysql"
    """
    connect_args = {}
    patchers = []
    # shared flag read by the cursor / ping / connect hooks below
    db_disconnected = False

    class DisconnectCursorMixin:
        def execute(self, *arg, **kw):
            if db_disconnected:
                raise exception
            else:
                return super().execute(*arg, **kw)

    if dialect_name == "postgresql":
        import psycopg2.extensions

        class Curs(DisconnectCursorMixin, psycopg2.extensions.cursor):
            pass

        connect_args = {"cursor_factory": Curs}

    elif dialect_name == "mysql":
        import pymysql

        def fake_ping(self, *arg, **kw):
            if db_disconnected:
                raise exception
            else:
                return True

        class Curs(DisconnectCursorMixin, pymysql.cursors.Cursor):
            pass

        connect_args = {"cursorclass": Curs}

        patchers.append(
            mock.patch.object(
                pymysql.Connection, "ping", fake_ping
            )
        )
    else:
        raise NotImplementedError()

    # create the engine with the pre-ping mode selected by the test class
    with mock.patch.object(
        compat,
        "native_pre_ping_event_support",
        self.native_pre_ping,
    ):
        engine = engines.create_engine(
            self.engine.url, max_retries=0)

    try:
        # 1. override how we connect. if we want the DB to be down
        # for the moment, but recover, reset db_disconnected after
        # connect is called. If we want the DB to stay down, then
        # make sure connect raises the error also.
        @event.listens_for(engine, "do_connect")
        def _connect(dialect, connrec, cargs, cparams):
            nonlocal db_disconnected

            # while we're here, add our cursor classes to the DBAPI
            # connect args
            cparams.update(connect_args)

            if db_disconnected:
                if db_stays_down:
                    raise exception
                else:
                    db_disconnected = False

        # 2. initialize the dialect with a first connect
        conn = engine.connect()
        conn.close()

        # 3. add additional patchers
        patchers.extend([
            mock.patch.object(
                engine.dialect.dbapi,
                "Error",
                self.Error,
            ),
            mock.patch.object(
                engine.dialect,
                "is_disconnect",
                mock.Mock(return_value=is_disconnect),
            ),
        ])

        with test_utils.nested(*patchers):
            # "disconnect" the DB
            db_disconnected = True
            yield engine
    finally:
        # the engine is private to this fixture; release its pooled
        # connections whether the test body passed or failed
        engine.dispose()
class MySQLPrePingHandlerTests(
db_test_base._MySQLOpportunisticTestCase,
TestDBDisconnectedFixture,
):
def test_mariadb_error_1927(self):
for code in [1927]:
self._test_ping_listener_disconnected(
@ -1298,6 +1399,26 @@ class TestDBDisconnected(TestsExceptionFilter):
is_disconnect=False
)
def test_mysql_w_disconnect_flag(self):
    """Run the ping-listener check for each MySQL connection error code.

    ``is_disconnect`` is left at the helper's default.
    """
    error_codes = (2002, 2003, 2002)
    for error_code in error_codes:
        gone_away = self.OperationalError(
            '%d MySQL server has gone away' % error_code
        )
        self._test_ping_listener_disconnected("mysql", gone_away)
def test_mysql_wo_disconnect_flag(self):
    """Run the ping-listener check with ``is_disconnect=False``.

    Same error messages as the "with flag" variant, but the helper is
    told that the dialect does not report them as disconnects.
    """
    error_codes = (2002, 2003)
    for error_code in error_codes:
        gone_away = self.OperationalError(
            '%d MySQL server has gone away' % error_code
        )
        self._test_ping_listener_disconnected(
            "mysql", gone_away, is_disconnect=False,
        )
class PostgreSQLPrePingHandlerTests(
db_test_base._PostgreSQLOpportunisticTestCase,
TestDBDisconnectedFixture):
def test_postgresql_ping_listener_disconnected(self):
self._test_ping_listener_disconnected(
"postgresql",
@ -1314,79 +1435,18 @@ class TestDBDisconnected(TestsExceptionFilter):
)
class TestDBConnectRetry(TestsExceptionFilter):
if compat.sqla_2:
class MySQLNativePrePingTests(MySQLPrePingHandlerTests):
native_pre_ping = True
def _run_test(self, dialect_name, exception, count, retries):
counter = itertools.count()
engine = self.engine
# empty out the connection pool
engine.dispose()
connect_fn = engine.dialect.connect
def cant_connect(*arg, **kw):
if next(counter) < count:
raise exception
else:
return connect_fn(*arg, **kw)
with self._dbapi_fixture(dialect_name):
with mock.patch.object(engine.dialect, "connect", cant_connect):
return engines._test_connection(engine, retries, .01)
def test_connect_no_retries(self):
conn = self._run_test(
"mysql",
self.OperationalError("Error: (2003) something wrong"),
2, 0
)
# didnt connect because nothing was tried
self.assertIsNone(conn)
def test_connect_inifinite_retries(self):
conn = self._run_test(
"mysql",
self.OperationalError("Error: (2003) something wrong"),
2, -1
)
# conn is good
self.assertEqual(1, conn.scalar(sqla.select(1)))
def test_connect_retry_past_failure(self):
conn = self._run_test(
"mysql",
self.OperationalError("Error: (2003) something wrong"),
2, 3
)
# conn is good
self.assertEqual(1, conn.scalar(sqla.select(1)))
def test_connect_retry_not_candidate_exception(self):
self.assertRaises(
sqla.exc.OperationalError, # remember, we pass OperationalErrors
# through at the moment :)
self._run_test,
"mysql",
self.OperationalError("Error: (2015) I can't connect period"),
2, 3
)
def test_connect_retry_stops_infailure(self):
self.assertRaises(
exception.DBConnectionError,
self._run_test,
"mysql",
self.OperationalError("Error: (2003) something wrong"),
3, 2
)
class PostgreSQLNativePrePingTests(PostgreSQLPrePingHandlerTests):
native_pre_ping = True
class TestDBConnectPingWrapping(TestsExceptionFilter):
class TestDBConnectPingListener(TestsExceptionFilter):
def setUp(self):
super(TestDBConnectPingWrapping, self).setUp()
super().setUp()
event.listen(
self.engine, "engine_connect", engines._connect_ping_listener)
@ -1475,6 +1535,75 @@ class TestDBConnectPingWrapping(TestsExceptionFilter):
)
class TestDBConnectRetry(TestsExceptionFilter):
    """Tests for ``engines._test_connection`` when connecting fails."""

    def _run_test(self, dialect_name, exception, count, retries):
        """Call ``engines._test_connection`` against a failing dialect.

        The dialect's ``connect`` is patched so that its first ``count``
        invocations raise ``exception``; later invocations are delegated
        to the real ``connect``.

        :param dialect_name: dialect name passed to ``_dbapi_fixture``
        :param exception: exception raised by the failing attempts
        :param count: number of connect attempts that should fail
        :param retries: retry count handed to ``_test_connection``
        :returns: whatever ``engines._test_connection`` returns
        """
        engine = self.engine

        # empty out the connection pool
        engine.dispose()

        real_connect = engine.dialect.connect
        attempts = itertools.count()

        def cant_connect(*arg, **kw):
            if next(attempts) >= count:
                return real_connect(*arg, **kw)
            raise exception

        with self._dbapi_fixture(dialect_name), mock.patch.object(
            engine.dialect, "connect", cant_connect
        ):
            return engines._test_connection(engine, retries, .01)

    def test_connect_no_retries(self):
        error = self.OperationalError("Error: (2003) something wrong")
        conn = self._run_test("mysql", error, 2, 0)

        # didn't connect because nothing was tried
        self.assertIsNone(conn)

    def test_connect_inifinite_retries(self):
        error = self.OperationalError("Error: (2003) something wrong")
        conn = self._run_test("mysql", error, 2, -1)

        # conn is good
        self.assertEqual(1, conn.scalar(sqla.select(1)))

    def test_connect_retry_past_failure(self):
        error = self.OperationalError("Error: (2003) something wrong")
        conn = self._run_test("mysql", error, 2, 3)

        # conn is good
        self.assertEqual(1, conn.scalar(sqla.select(1)))

    def test_connect_retry_not_candidate_exception(self):
        error = self.OperationalError("Error: (2015) I can't connect period")

        # remember, we pass OperationalErrors through at the moment :)
        with self.assertRaises(sqla.exc.OperationalError):
            self._run_test("mysql", error, 2, 3)

    def test_connect_retry_stops_infailure(self):
        error = self.OperationalError("Error: (2003) something wrong")

        with self.assertRaises(exception.DBConnectionError):
            self._run_test("mysql", error, 3, 2)
class TestsErrorHandler(TestsExceptionFilter):
def test_multiple_error_handlers(self):
handler = mock.MagicMock(return_value=None)