diff --git a/oslo_db/sqlalchemy/compat/__init__.py b/oslo_db/sqlalchemy/compat/__init__.py index 6713696b..d2092071 100644 --- a/oslo_db/sqlalchemy/compat/__init__.py +++ b/oslo_db/sqlalchemy/compat/__init__.py @@ -18,6 +18,8 @@ from sqlalchemy import __version__ _vers = versionutils.convert_version_to_tuple(__version__) sqla_2 = _vers >= (2, ) +native_pre_ping_event_support = _vers >= (2, 0, 5) + def dialect_from_exception_context(ctx): if sqla_2: diff --git a/oslo_db/sqlalchemy/engines.py b/oslo_db/sqlalchemy/engines.py index 146d1890..7c36c8a8 100644 --- a/oslo_db/sqlalchemy/engines.py +++ b/oslo_db/sqlalchemy/engines.py @@ -60,6 +60,12 @@ def _connect_ping_listener(connection, branch): Ping the server at transaction begin and transparently reconnect if a disconnect exception occurs. + This listener is used up until SQLAlchemy 2.0.5. At 2.0.5, we use the + ``pool_pre_ping`` parameter instead of this event handler. + + Note the current test suite in test_exc_filters still **tests** this + handler using all SQLAlchemy versions including 2.0.5 and greater. + """ if branch: return @@ -199,8 +205,11 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, _vet_url(url) + _native_pre_ping = compat.native_pre_ping_event_support + engine_args = { 'pool_recycle': connection_recycle_time, + 'pool_pre_ping': _native_pre_ping, 'connect_args': {}, 'logging_name': logging_name } @@ -236,9 +245,10 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, # register alternate exception handler exc_filters.register_engine(engine) - # register engine connect handler + if not _native_pre_ping: + # register engine connect handler. - event.listen(engine, "engine_connect", _connect_ping_listener) + event.listen(engine, "engine_connect", _connect_ping_listener) # initial connect + test # NOTE(viktors): the current implementation of _test_connection() diff --git a/oslo_db/sqlalchemy/exc_filters.py b/oslo_db/sqlalchemy/exc_filters.py index e5789875..420b5c7d 100644 --- a/oslo_db/sqlalchemy/exc_filters.py +++ b/oslo_db/sqlalchemy/exc_filters.py @@ -20,7 +20,7 @@ from sqlalchemy import event from sqlalchemy import exc as sqla_exc from oslo_db import exception - +from oslo_db.sqlalchemy import compat LOG = logging.getLogger(__name__) @@ -377,6 +377,7 @@ def _raise_operational_errors_directly_filter(operational_error, def _is_db_connection_error(operational_error, match, engine_name, is_disconnect): """Detect the exception as indicating a recoverable error on connect.""" + raise exception.DBConnectionError(operational_error) @@ -423,13 +424,14 @@ def handler(context): more specific exception class are attempted first. """ - def _dialect_registries(engine): - if engine.dialect.name in _registry: - yield _registry[engine.dialect.name] + def _dialect_registries(dialect): + if dialect.name in _registry: + yield _registry[dialect.name] if '*' in _registry: yield _registry['*'] - for per_dialect in _dialect_registries(context.engine): + dialect = compat.dialect_from_exception_context(context) + for per_dialect in _dialect_registries(dialect): for exc in ( context.sqlalchemy_exception, context.original_exception): @@ -443,7 +445,7 @@ def handler(context): fn( exc, match, - context.engine.dialect.name, + dialect.name, context.is_disconnect) except exception.DBError as dbe: if ( @@ -460,6 +462,19 @@ def handler(context): if isinstance( dbe, exception.DBConnectionError): context.is_disconnect = True + + # new in 2.0.5 + if ( + hasattr(context, "is_pre_ping") and + context.is_pre_ping + ): + # if this is a pre-ping, need to + # integrate with the built + # in pre-ping handler that doesnt know + # about DBConnectionError, just needs + # the updated status + return None + return dbe diff --git a/oslo_db/tests/sqlalchemy/test_exc_filters.py b/oslo_db/tests/sqlalchemy/test_exc_filters.py index af3cd918..796ba6c3 100644 --- a/oslo_db/tests/sqlalchemy/test_exc_filters.py +++ b/oslo_db/tests/sqlalchemy/test_exc_filters.py @@ -1190,41 +1190,16 @@ class IntegrationTest(db_test_base._DbTestCase): self.assertIn("no such function", str(matched)) -class TestDBDisconnected(TestsExceptionFilter): - - @contextlib.contextmanager - def _fixture( - self, - dialect_name, exception, num_disconnects, is_disconnect=True): - engine = self.engine - - event.listen(engine, "engine_connect", engines._connect_ping_listener) - - real_do_execute = engine.dialect.do_execute - counter = itertools.count(1) - - def fake_do_execute(self, *arg, **kw): - if next(counter) > num_disconnects: - return real_do_execute(self, *arg, **kw) - else: - raise exception - - with self._dbapi_fixture(dialect_name): - with test_utils.nested( - mock.patch.object(engine.dialect, - "do_execute", - fake_do_execute), - mock.patch.object(engine.dialect, - "is_disconnect", - mock.Mock(return_value=is_disconnect)) - ): - yield +class TestDBDisconnectedFixture(TestsExceptionFilter): + native_pre_ping = False def _test_ping_listener_disconnected( self, dialect_name, exc_obj, is_disconnect=True, ): - with self._fixture(dialect_name, exc_obj, 1, is_disconnect): - conn = self.engine.connect() + with self._fixture( + dialect_name, exc_obj, False, is_disconnect, + ) as engine: + conn = engine.connect() with conn.begin(): self.assertEqual( 1, conn.execute(sqla.select(1)).scalars().first(), @@ -1233,19 +1208,145 @@ class TestDBDisconnected(TestsExceptionFilter): self.assertFalse(conn.invalidated) self.assertTrue(conn.in_transaction()) - with self._fixture(dialect_name, exc_obj, 2, is_disconnect): + with self._fixture( + dialect_name, exc_obj, True, is_disconnect, + ) as engine: self.assertRaises( exception.DBConnectionError, - self.engine.connect + engine.connect ) # test implicit execution - with self._fixture(dialect_name, exc_obj, 1): - with self.engine.connect() as conn: + with self._fixture(dialect_name, exc_obj, False) as engine: + with engine.connect() as conn: self.assertEqual( 1, conn.execute(sqla.select(1)).scalars().first(), ) + @contextlib.contextmanager + def _fixture( + self, + dialect_name, + exception, + db_stays_down, + is_disconnect=True, + ): + """Fixture for testing the ping listener. + + For SQLAlchemy 2.0, the mocking is placed more deeply in the + stack within the DBAPI connection / cursor so that we can also + effectively mock out the "pre ping" condition. + + :param dialect_name: dialect to use. "postgresql" or "mysql" + :param exception: an exception class to raise + :param db_stays_down: if True, the database will stay down after the + first ping fails + :param is_disconnect: whether or not the SQLAlchemy dialect should + consider the exception object as a "disconnect error". Openstack's + own exception handlers upgrade various DB exceptions to be + "disconnect" scenarios that SQLAlchemy itself does not, such as + some specific Galera error messages. + + The importance of an exception being a "disconnect error" means that + SQLAlchemy knows it can discard the connection and then reconnect. + If the error is not a "disconnection error", then it raises. + """ + connect_args = {} + patchers = [] + db_disconnected = False + + class DisconnectCursorMixin: + def execute(self, *arg, **kw): + if db_disconnected: + raise exception + else: + return super().execute(*arg, **kw) + + if dialect_name == "postgresql": + import psycopg2.extensions + + class Curs(DisconnectCursorMixin, psycopg2.extensions.cursor): + pass + + connect_args = {"cursor_factory": Curs} + + elif dialect_name == "mysql": + import pymysql + + def fake_ping(self, *arg, **kw): + if db_disconnected: + raise exception + else: + return True + + class Curs(DisconnectCursorMixin, pymysql.cursors.Cursor): + pass + + connect_args = {"cursorclass": Curs} + + patchers.append( + mock.patch.object( + pymysql.Connection, "ping", fake_ping + ) + ) + else: + raise NotImplementedError() + + with mock.patch.object( + compat, + "native_pre_ping_event_support", + self.native_pre_ping, + ): + engine = engines.create_engine( + self.engine.url, max_retries=0) + + # 1. override how we connect. if we want the DB to be down + # for the moment, but recover, reset db_disconnected after + # connect is called. If we want the DB to stay down, then + # make sure connect raises the error also. + @event.listens_for(engine, "do_connect") + def _connect(dialect, connrec, cargs, cparams): + nonlocal db_disconnected + + # while we're here, add our cursor classes to the DBAPI + # connect args + cparams.update(connect_args) + + if db_disconnected: + if db_stays_down: + raise exception + else: + db_disconnected = False + + # 2. initialize the dialect with a first connect + conn = engine.connect() + conn.close() + + # 3. add additional patchers + patchers.extend([ + mock.patch.object( + engine.dialect.dbapi, + "Error", + self.Error, + ), + mock.patch.object( + engine.dialect, + "is_disconnect", + mock.Mock(return_value=is_disconnect), + ), + ]) + + with test_utils.nested(*patchers): + # "disconnect" the DB + db_disconnected = True + yield engine + + +class MySQLPrePingHandlerTests( + db_test_base._MySQLOpportunisticTestCase, + TestDBDisconnectedFixture, +): + def test_mariadb_error_1927(self): for code in [1927]: self._test_ping_listener_disconnected( @@ -1298,6 +1399,26 @@ class TestDBDisconnected(TestsExceptionFilter): is_disconnect=False ) + def test_mysql_w_disconnect_flag(self): + for code in [2002, 2003, 2002]: + self._test_ping_listener_disconnected( + "mysql", + self.OperationalError('%d MySQL server has gone away' % code) + ) + + def test_mysql_wo_disconnect_flag(self): + for code in [2002, 2003]: + self._test_ping_listener_disconnected( + "mysql", + self.OperationalError('%d MySQL server has gone away' % code), + is_disconnect=False + ) + + +class PostgreSQLPrePingHandlerTests( + db_test_base._PostgreSQLOpportunisticTestCase, + TestDBDisconnectedFixture): + def test_postgresql_ping_listener_disconnected(self): self._test_ping_listener_disconnected( "postgresql", @@ -1314,79 +1435,18 @@ class TestDBDisconnected(TestsExceptionFilter): ) -class TestDBConnectRetry(TestsExceptionFilter): +if compat.sqla_2: + class MySQLNativePrePingTests(MySQLPrePingHandlerTests): + native_pre_ping = True - def _run_test(self, dialect_name, exception, count, retries): - counter = itertools.count() - - engine = self.engine - - # empty out the connection pool - engine.dispose() - - connect_fn = engine.dialect.connect - - def cant_connect(*arg, **kw): - if next(counter) < count: - raise exception - else: - return connect_fn(*arg, **kw) - - with self._dbapi_fixture(dialect_name): - with mock.patch.object(engine.dialect, "connect", cant_connect): - return engines._test_connection(engine, retries, .01) - - def test_connect_no_retries(self): - conn = self._run_test( - "mysql", - self.OperationalError("Error: (2003) something wrong"), - 2, 0 - ) - # didnt connect because nothing was tried - self.assertIsNone(conn) - - def test_connect_inifinite_retries(self): - conn = self._run_test( - "mysql", - self.OperationalError("Error: (2003) something wrong"), - 2, -1 - ) - # conn is good - self.assertEqual(1, conn.scalar(sqla.select(1))) - - def test_connect_retry_past_failure(self): - conn = self._run_test( - "mysql", - self.OperationalError("Error: (2003) something wrong"), - 2, 3 - ) - # conn is good - self.assertEqual(1, conn.scalar(sqla.select(1))) - - def test_connect_retry_not_candidate_exception(self): - self.assertRaises( - sqla.exc.OperationalError, # remember, we pass OperationalErrors - # through at the moment :) - self._run_test, - "mysql", - self.OperationalError("Error: (2015) I can't connect period"), - 2, 3 - ) - - def test_connect_retry_stops_infailure(self): - self.assertRaises( - exception.DBConnectionError, - self._run_test, - "mysql", - self.OperationalError("Error: (2003) something wrong"), - 3, 2 - ) + class PostgreSQLNativePrePingTests(PostgreSQLPrePingHandlerTests): + native_pre_ping = True -class TestDBConnectPingWrapping(TestsExceptionFilter): +class TestDBConnectPingListener(TestsExceptionFilter): def setUp(self): - super(TestDBConnectPingWrapping, self).setUp() + super().setUp() event.listen( self.engine, "engine_connect", engines._connect_ping_listener) @@ -1475,6 +1535,75 @@ class TestDBConnectPingWrapping(TestsExceptionFilter): ) +class TestDBConnectRetry(TestsExceptionFilter): + + def _run_test(self, dialect_name, exception, count, retries): + counter = itertools.count() + + engine = self.engine + + # empty out the connection pool + engine.dispose() + + connect_fn = engine.dialect.connect + + def cant_connect(*arg, **kw): + if next(counter) < count: + raise exception + else: + return connect_fn(*arg, **kw) + + with self._dbapi_fixture(dialect_name): + with mock.patch.object(engine.dialect, "connect", cant_connect): + return engines._test_connection(engine, retries, .01) + + def test_connect_no_retries(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, 0 + ) + # didnt connect because nothing was tried + self.assertIsNone(conn) + + def test_connect_inifinite_retries(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, -1 + ) + # conn is good + self.assertEqual(1, conn.scalar(sqla.select(1))) + + def test_connect_retry_past_failure(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, 3 + ) + # conn is good + self.assertEqual(1, conn.scalar(sqla.select(1))) + + def test_connect_retry_not_candidate_exception(self): + self.assertRaises( + sqla.exc.OperationalError, # remember, we pass OperationalErrors + # through at the moment :) + self._run_test, + "mysql", + self.OperationalError("Error: (2015) I can't connect period"), + 2, 3 + ) + + def test_connect_retry_stops_infailure(self): + self.assertRaises( + exception.DBConnectionError, + self._run_test, + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 3, 2 + ) + + class TestsErrorHandler(TestsExceptionFilter): def test_multiple_error_handlers(self): handler = mock.MagicMock(return_value=None)