diff --git a/oslo_db/options.py b/oslo_db/options.py index 111ad244..f55b76aa 100644 --- a/oslo_db/options.py +++ b/oslo_db/options.py @@ -40,6 +40,22 @@ database_opts = [ 'slave database.' ), ), + cfg.StrOpt( + 'asyncio_connection', + help=( + 'The SQLAlchemy asyncio connection string to use to connect to ' + 'the database.' + ), + secret=True, + ), + cfg.StrOpt( + 'asyncio_slave_connection', + help=( + 'The SQLAlchemy asyncio connection string to use to connect to ' + 'the slave database.' + ), + secret=True, + ), cfg.StrOpt( 'mysql_sql_mode', default='TRADITIONAL', diff --git a/oslo_db/sqlalchemy/asyncio_facade.py b/oslo_db/sqlalchemy/asyncio_facade.py new file mode 100644 index 00000000..c762e002 --- /dev/null +++ b/oslo_db/sqlalchemy/asyncio_facade.py @@ -0,0 +1,510 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import contextlib +import contextvars +import functools +import inspect +import itertools +import logging +import operator +from typing import AsyncContextManager +from typing import AsyncIterator +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from oslo_db import exception +from oslo_db.sqlalchemy import engines +from oslo_utils import excutils +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.util.concurrency import await_only +from sqlalchemy.util.concurrency import greenlet_spawn + +from . import enginefacade as _sync_facade +from .enginefacade import _AbstractTransactionContext +from .enginefacade import _AbstractTransactionContextManager +from .enginefacade import _AbstractTransactionFactory + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection + from sqlalchemy.ext.asyncio import AsyncSession + +_AES = TypeVar("_AES", bound=Union["AsyncConnection", "AsyncSession", None]) + +LOG = logging.getLogger(__name__) + + +def _create_async_engine(url, **engine_args): + async_engine = create_async_engine(url, **engine_args) + return async_engine, async_engine.engine + + +def _test_async_connection(engine, max_retries, retry_interval): + if max_retries == -1: + attempts = itertools.count() + else: + attempts = range(max_retries) + # See: http://legacy.python.org/dev/peps/pep-3110/#semantic-changes for + # why we are not using 'de' directly (it can be removed from the local + # scope). + de_ref = None + for attempt in attempts: + try: + conn = engine.connect() + except exception.DBConnectionError as de: + msg = "SQL async connection failed. %s attempts left." + LOG.warning(msg, max_retries - attempt) + await_only(asyncio.sleep(retry_interval)) + de_ref = de + else: + conn.close() + else: + if de_ref is not None: + raise de_ref + + +class _GreenletAdaptedLock: + def __init__(self): + self._lock = asyncio.Lock() + + def __enter__(self): + await_only(self._lock.acquire()) + + def __exit__(self, *arg, **kw): + self._lock.release() + + +class _AsyncioTransactionFactory(_AbstractTransactionFactory): + """A factory for :class:`._AsyncioTransactionContext` objects. + + By default, there is just one of these, set up + based on CONF, however instance-level :class:`._TransactionFactory` + objects can be made, as is the case with the + :class:`._TestTransactionFactory` subclass used by the oslo.db test suite. + """ + + _start_lock: _GreenletAdaptedLock + + _writer_engine: AsyncEngine + _reader_engine: AsyncEngine + _writer_maker: async_sessionmaker + _reader_maker: async_sessionmaker + + def _make_lock(self): + return _GreenletAdaptedLock() + + if TYPE_CHECKING: + + def get_writer_engine(self) -> AsyncEngine: + """Return the writer engine for this factory. + + Implies start. + """ + ... + + def get_reader_engine(self) -> AsyncEngine: + """Return the reader engine for this factory. + + Implies start. + """ + ... + + def get_writer_maker(self) -> async_sessionmaker: + """Return the writer sessionmaker for this factory. + + Implies start. + """ + ... + + def get_reader_maker(self) -> async_sessionmaker: + """Return the reader sessionmaker for this factory. + + Implies start. + """ + ... + + async def dispose_pool_async(self): + """Call engine.dispose() on underlying AsyncEngine objects.""" + + async with self._start_lock._lock: + if not self._started: + return + + await self._writer_engine.dispose() + if self._reader_engine is not self._writer_engine: + await self._reader_engine.dispose() + + def _setup_for_connection( + self, + sql_connection, + engine_kwargs, + maker_kwargs, + ): + if sql_connection is None: + raise exception.CantStartEngineError( + "No async sql_connection parameter is established" + ) + + engine = engines.create_engine( + sql_connection=sql_connection, + _engine_target=_create_async_engine, + _test_connection=_test_async_connection, + **engine_kwargs, + ) + for hook in self._facade_cfg["on_engine_create"]: + hook(engine) + sessionmaker = async_sessionmaker(bind=engine, **maker_kwargs) + return engine, sessionmaker + + async def _create_async_connection(self, mode): + if not self._started: + await greenlet_spawn(self._start) + if mode is _sync_facade._WRITER: + return await self._writer_engine.connect() + elif mode is _sync_facade._ASYNC_READER or ( + mode is _sync_facade._READER and not self.synchronous_reader + ): + return await self._reader_engine.connect() + else: + return await self._writer_engine.connect() + + async def _create_async_session(self, mode, bind=None): + if not self._started: + await greenlet_spawn(self._start) + kw = {} + # don't pass 'bind' if bind is None; the sessionmaker + # already has a bind to the engine. + if bind: + kw["bind"] = bind + if mode is _sync_facade._WRITER: + return self._writer_maker(**kw) + elif mode is _sync_facade._ASYNC_READER or ( + mode is _sync_facade._READER and not self.synchronous_reader + ): + return self._reader_maker(**kw) + else: + return self._writer_maker(**kw) + + +class _TransactionContextContextVar: + __slots__ = ("_context_vars",) + + def __init__(self): + object.__setattr__(self, "_context_vars", {}) + + def _get_context_var(self, key): + if key not in self._context_vars: + self._context_vars[key] = var = contextvars.ContextVar(key) + else: + var = self._context_vars[key] + return var + + def __getattr__(self, key): + context_var = self._get_context_var(key) + try: + return context_var.get() + except LookupError as le: + raise AttributeError(key) from le + + def __setattr__(self, key, value): + context_var = self._get_context_var(key) + context_var.set(value) + + def __delattr__(self, key): + context_var = self._get_context_var(key) + context_var.set(None) + + def __deepcopy__(self, memo): + return self + + def __reduce__(self): + return _TransactionContextContextVar, () + + +def _async_transaction_ctx_for_context(context): + by_coroutine = _transaction_contexts_by_coroutine(context) + try: + return by_coroutine.current + except AttributeError: + raise exception.NoEngineContextEstablished( + "No AsyncTransactionContext is established for " + f"this {context} object within the current thread. " + "Ensure this Context class participates in the " + "async_transaction_context_provider() class decorator." + ) + + +def _transaction_contexts_by_coroutine(context): + transaction_contexts_by_coroutine = getattr( + context, "_asyncio_facade_context", None + ) + if transaction_contexts_by_coroutine is None: + transaction_contexts_by_coroutine = context._asyncio_facade_context = ( + _TransactionContextContextVar() + ) + + return transaction_contexts_by_coroutine + + +class _AsyncTransactionContextManager( + _AbstractTransactionContextManager[_AES] +): + + def using(self, context) -> AsyncContextManager[_AES]: + """Provide a context manager block that will use the given context.""" + return self._async_transaction_scope(context) + + @property + def connection(self) -> _AsyncTransactionContextManager[AsyncConnection]: + """Modifier to return a core Connection object instead of Session.""" + return self._clone(connection=True) + + def __call__(self, fn): + """Decorate an awaitable function.""" + argspec = inspect.getfullargspec(fn) + if argspec.args[0] == "self" or argspec.args[0] == "cls": + context_index = 1 + else: + context_index = 0 + context_kw = argspec.args[context_index] + + @functools.wraps(fn) + async def wrapper(*args, **kwargs): + context = kwargs.get(context_kw, None) + if not context: + context = args[context_index] + + async with self._async_transaction_scope(context): + return await fn(*args, **kwargs) + + return wrapper + + @contextlib.asynccontextmanager + async def _async_transaction_scope(self, context) -> AsyncIterator[_AES]: + transaction_contexts_by_coroutine = _transaction_contexts_by_coroutine( + context + ) + + current, restore = self._transaction_scope_impl( + transaction_contexts_by_coroutine + ) + + try: + if self._mode is not None: + async with current._produce_asyncio_block( + mode=self._mode, + connection=self._connection, + savepoint=self._savepoint, + allow_async=self._allow_async, + context=context, + ) as resource: + yield resource + else: + yield + finally: + if restore is None: + del transaction_contexts_by_coroutine.current + elif current is not restore: + transaction_contexts_by_coroutine.current = restore + + def _create_root_factory(self): + return _AsyncioTransactionFactory() + + def _create_transaction_context(self, use_factory, global_factory): + return _AsyncTransactionContext( + use_factory, global_factory=global_factory + ) + + +class _AsyncTransactionContext(_AbstractTransactionContext): + + @property + def async_connection(self): + return self.connection + + @property + def async_session(self): + return self.session + + async def _end_async_session_transaction(self, session): + if self.mode is _sync_facade._WRITER: + await session.commit() + elif self.rollback_reader_sessions: + await session.rollback() + # In the absence of calling session.rollback(), + # the next call is session.close(). This releases all + # objects from the session into the detached state, and + # releases the connection as well; the connection when returned + # to the pool is either rolled back in any case, or closed fully. + + async def _end_async_connection_transaction(self, transaction): + if self.mode is _sync_facade._WRITER: + await transaction.commit() + else: + await transaction.rollback() + + def _produce_asyncio_block( + self, mode, connection, savepoint, allow_async=False, context=None + ): + if mode is _sync_facade._WRITER: + self._writer() + elif mode is _sync_facade._ASYNC_READER: + self._async_reader() + else: + self._reader(allow_async) + if connection: + return self._async_connection(savepoint, context=context) + else: + return self._async_session(savepoint, context=context) + + @contextlib.asynccontextmanager + async def _async_connection(self, savepoint=False, context=None): + if self.connection is None: + try: + if self.session is not None: + # use existing session, which is outer to us + self.connection = await self.session.connection() + if savepoint: + with self.connection.begin_nested(), self._add_context( + self.connection, context + ): + yield self.connection + else: + with self._add_context(self.connection, context): + yield self.connection + else: + # is outermost + self.connection = ( + await self.factory._create_async_connection( + mode=self.mode + ) + ) + self.transaction = await self.connection.begin() + try: + with self._add_context(self.connection, context): + yield self.connection + await self._end_async_connection_transaction( + self.transaction + ) + except Exception: + await self.transaction.rollback() + # TODO(zzzeek) do we need save_and_reraise() here, + # or do newer eventlets not have issues? we are using + # raw "raise" in many other places in oslo.db already + raise + finally: + self.transaction = None + await self.connection.close() + finally: + self.connection = None + + else: + # use existing connection, which is outer to us + if savepoint: + async with self.connection.begin_nested(): + with self._add_context(self.connection, context): + yield self.connection + else: + with self._add_context(self.connection, context): + yield self.connection + + @contextlib.asynccontextmanager + async def _async_session(self, savepoint=False, context=None): + if self.session is None: + self.session = await self.factory._create_async_session( + bind=self.connection, mode=self.mode + ) + try: + await self.session.begin() + with self._add_context(self.session, context): + yield self.session + await self._end_async_session_transaction(self.session) + except Exception: + with excutils.save_and_reraise_exception(): + await self.session.rollback() + finally: + await self.session.close() + self.session = None + else: + # use existing session, which is outer to us + if savepoint: + async with self.session.begin_nested(): + with self._add_context(self.session, context): + yield self.session + else: + with self._add_context(self.session, context): + yield self.session + if self.flush_on_subtransaction: + await self.session.flush() + + +def configure(**kw): + """Apply configurational options to the global factory. + + This method can only be called before any specific transaction-beginning + methods have been called. + + .. seealso:: + + :meth:`._TransactionFactory.configure` + + """ + _async_context_manager._factory.configure(**kw) + + +def async_transaction_context_provider(klass): + """Decorate a class with ``session`` and ``connection`` attributes.""" + + setattr( + klass, + "async_transaction_ctx", + property(_async_transaction_ctx_for_context), + ) + + # Graft transaction context attributes as context properties + for attr in ("async_session", "async_connection", "async_transaction"): + setattr( + klass, + attr, + _sync_facade._context_descriptor( + attr, + transaction_ctx_getter=operator.attrgetter( + "async_transaction_ctx" + ), + ), + ) + + return klass + + +_async_context_manager: _AsyncTransactionContextManager[AsyncSession] = ( + _AsyncTransactionContextManager(_is_global_manager=True) +) +"""default context manager.""" + + +def transaction_context(): + """Construct a local transaction context.""" + return _AsyncTransactionContextManager() + + +reader = _async_context_manager.reader +"""The global 'reader' starting point.""" + + +writer = _async_context_manager.writer +"""The global 'writer' starting point.""" diff --git a/oslo_db/sqlalchemy/enginefacade.py b/oslo_db/sqlalchemy/enginefacade.py index dccc6d6e..de260ecf 100644 --- a/oslo_db/sqlalchemy/enginefacade.py +++ b/oslo_db/sqlalchemy/enginefacade.py @@ -9,60 +9,91 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +from __future__ import annotations import contextlib +from enum import Enum import functools import inspect import operator import threading +from typing import ContextManager +from typing import Generic +from typing import Protocol +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union import warnings import debtcollector.removals import debtcollector.renames from oslo_config import cfg -from oslo_utils import excutils - from oslo_db import exception from oslo_db import options from oslo_db.sqlalchemy import engines from oslo_db.sqlalchemy import orm from oslo_db import warning +from oslo_utils import excutils -class _symbol(object): - """represent a fixed symbol.""" +if TYPE_CHECKING: + from typing import Any + from typing import Self - __slots__ = 'name', + from sqlalchemy.engine import Connection + from sqlalchemy.engine import Engine + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.ext.asyncio import AsyncConnection + from sqlalchemy.ext.asyncio import AsyncEngine + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.orm import Session + from sqlalchemy.orm import sessionmaker - def __init__(self, name): - self.name = name +_AllEngineSessionT = TypeVar( + "_AllEngineSessionT", + bound=Union[ + "Connection", "Session", "AsyncConnection", "AsyncSession", None + ], +) - def __repr__(self): - return "symbol(%r)" % self.name +_SyncEngineSessionT = TypeVar( + "_SyncEngineSessionT", bound=Union["Connection", "Session", None] +) -_ASYNC_READER = _symbol('ASYNC_READER') -"""Represent the transaction state of "async reader". +class _FacadeSymbols(Enum): + _ASYNC_READER = 1 + """Represent the transaction state of "async reader". -This state indicates that the transaction is a read-only and is -safe to use on an asynchronously updated slave database. -""" + This state indicates that the transaction is a read-only and is + safe to use on an asynchronously updated slave database. + """ -_READER = _symbol('READER') -"""Represent the transaction state of "reader". + _READER = 2 + """Represent the transaction state of "reader". -This state indicates that the transaction is a read-only and is -only safe to use on a synchronously updated slave database; otherwise -the master database should be used. -""" + This state indicates that the transaction is a read-only and is + only safe to use on a synchronously updated slave database; otherwise + the master database should be used. + """ + + _WRITER = 3 + """Represent the transaction state of "writer". + + This state indicates that the transaction writes data and + should be directed at the master database. + """ -_WRITER = _symbol('WRITER') -"""Represent the transaction state of "writer". +_ASYNC_READER, _READER, _WRITER = ( + _FacadeSymbols._ASYNC_READER, + _FacadeSymbols._READER, + _FacadeSymbols._WRITER, +) -This state indicates that the transaction writes data and -should be directed at the master database. -""" + +class _NotSet(Enum): + NOTSET = 1 class _Default: @@ -74,9 +105,9 @@ class _Default: will supersede those in cfg.CONF. """ - __slots__ = 'value', + __slots__ = ("value",) - _notset = _symbol("NOTSET") + _notset = _NotSet.NOTSET def __init__(self, value=_notset): self.value = value @@ -119,7 +150,7 @@ class _Default: # oslo.config doesn't provide a public API to retrieve the opt # itself, as opposed to the value of the opt :( - opt = conf.database._group._opts[key]['opt'] + opt = conf.database._group._opts[key]["opt"] # ditto for the group group = conf.database._group if ( @@ -141,67 +172,92 @@ class AlreadyStartedError(TypeError): """ -class _TransactionFactory: - """A factory for :class:`._TransactionContext` objects. +class _EnterableLock(Protocol): + def __enter__(self): + ... + + def __exit__(self, *arg, **kw): + ... + + +class _AbstractTransactionFactory: + """Abstract base for _TransactionFactory.""" + + _start_lock: _EnterableLock + _writer_engine: Engine | AsyncEngine + _reader_engine: Engine | AsyncEngine + _writer_maker: sessionmaker | async_sessionmaker + _reader_maker: sessionmaker | async_sessionmaker - By default, there is just one of these, set up - based on CONF, however instance-level :class:`._TransactionFactory` - objects can be made, as is the case with the - :class:`._TestTransactionFactory` subclass used by the oslo.db test suite. - """ def __init__(self): self._url_cfg = { - 'connection': _Default(), - 'slave_connection': _Default(), + "connection": _Default(), + "slave_connection": _Default(), + "asyncio_connection": _Default(), + "asyncio_slave_connection": _Default(), } self._engine_cfg = { - 'sqlite_fk': _Default(False), - 'mysql_sql_mode': _Default('TRADITIONAL'), - 'mysql_wsrep_sync_wait': _Default(), - 'connection_recycle_time': _Default(3600), - 'connection_debug': _Default(0), - 'max_pool_size': _Default(), - 'max_overflow': _Default(), - 'pool_timeout': _Default(), - 'sqlite_synchronous': _Default(True), - 'connection_trace': _Default(False), - 'max_retries': _Default(10), - 'retry_interval': _Default(10), - 'thread_checkin': _Default(True), - 'json_serializer': _Default(None), - 'json_deserializer': _Default(None), - 'logging_name': _Default(None), - 'connection_parameters': _Default(None) + "sqlite_fk": _Default(False), + "mysql_sql_mode": _Default("TRADITIONAL"), + "mysql_wsrep_sync_wait": _Default(), + "connection_recycle_time": _Default(3600), + "connection_debug": _Default(0), + "max_pool_size": _Default(), + "max_overflow": _Default(), + "pool_timeout": _Default(), + "sqlite_synchronous": _Default(True), + "connection_trace": _Default(False), + "max_retries": _Default(10), + "retry_interval": _Default(10), + "thread_checkin": _Default(True), + "json_serializer": _Default(None), + "json_deserializer": _Default(None), + "logging_name": _Default(None), + "connection_parameters": _Default(None), } self._maker_cfg = { - 'expire_on_commit': _Default(False), + "expire_on_commit": _Default(False), } self._transaction_ctx_cfg = { - 'rollback_reader_sessions': False, - 'flush_on_subtransaction': False, - } + "rollback_reader_sessions": False, + "flush_on_subtransaction": False, + } self._facade_cfg = { - 'synchronous_reader': True, - 'on_engine_create': [], + "synchronous_reader": True, + "on_engine_create": [], } # other options that are defined in oslo_db.options.database_opts # but do not apply to the standard enginefacade arguments (most seem # to apply to api.DBAPI). self._ignored_cfg = dict( - (k, _Default(None)) for k in [ - 'db_max_retries', 'db_inc_retry_interval', - 'use_db_reconnect', - 'db_retry_interval', - 'db_max_retry_interval', 'backend', + (k, _Default(None)) + for k in [ + "db_max_retries", + "db_inc_retry_interval", + "use_db_reconnect", + "db_retry_interval", + "db_max_retry_interval", + "backend", ] ) self._started = False self._legacy_facade = None - self._start_lock = threading.Lock() + self._start_lock = self._make_lock() - def configure_defaults(self, **kw): + def _make_lock(self) -> _EnterableLock: + raise NotImplementedError() + + def _setup_for_connection( + self, + sql_connection, + engine_kwargs, + maker_kwargs, + ): + raise NotImplementedError() + + def configure_defaults(self, **kw: Any) -> None: """Apply default configurational options. This method can only be called before any specific @@ -224,6 +280,8 @@ class _TransactionFactory: :param connection: database URL :param slave_connection: database URL + :param asyncio_connection: database URL + :param asyncio_slave_connection: database URL :param sqlite_fk: whether to enable SQLite foreign key pragma; default False :param mysql_sql_mode: MySQL SQL mode, defaults to TRADITIONAL @@ -289,7 +347,7 @@ class _TransactionFactory: """ self._configure(True, kw) - def configure(self, **kw): + def configure(self, **kw: Any) -> None: """Apply configurational options. This method can only be called before any specific @@ -316,9 +374,12 @@ class _TransactionFactory: not_supported = [] for k, v in kw.items(): for dict_ in ( - self._url_cfg, self._engine_cfg, - self._maker_cfg, self._ignored_cfg, - self._facade_cfg, self._transaction_ctx_cfg, + self._url_cfg, + self._engine_cfg, + self._maker_cfg, + self._ignored_cfg, + self._facade_cfg, + self._transaction_ctx_cfg, ): if k in dict_: dict_[k] = _Default(v) if as_defaults else v @@ -331,27 +392,12 @@ class _TransactionFactory: # too many unrecognized (obsolete?) configuration options # coming in from projects warnings.warn( - "Configuration option(s) %r not supported" % - sorted(not_supported), - warning.NotSupportedWarning + "Configuration option(s) %r not supported" + % sorted(not_supported), + warning.NotSupportedWarning, ) - def get_legacy_facade(self): - """Return a :class:`.LegacyEngineFacade` for this factory. - - This facade will make use of the same engine and sessionmaker - as this factory, however will not share the same transaction context; - the legacy facade continues to work the old way of returning - a new Session each time get_session() is called. - """ - if not self._legacy_facade: - self._legacy_facade = LegacyEngineFacade(None, _factory=self) - if not self._started: - self._start() - - return self._legacy_facade - - def get_writer_engine(self): + def get_writer_engine(self) -> Engine | AsyncEngine: """Return the writer engine for this factory. Implies start. @@ -360,7 +406,7 @@ class _TransactionFactory: self._start() return self._writer_engine - def get_reader_engine(self): + def get_reader_engine(self) -> Engine | AsyncEngine: """Return the reader engine for this factory. Implies start. @@ -369,7 +415,7 @@ class _TransactionFactory: self._start() return self._reader_engine - def get_writer_maker(self): + def get_writer_maker(self) -> sessionmaker | async_sessionmaker: """Return the writer sessionmaker for this factory. Implies start. @@ -378,7 +424,7 @@ class _TransactionFactory: self._start() return self._writer_maker - def get_reader_maker(self): + def get_reader_maker(self) -> sessionmaker | async_sessionmaker: """Return the reader sessionmaker for this factory. Implies start. @@ -387,13 +433,23 @@ class _TransactionFactory: self._start() return self._reader_maker + def _create_factory_copy(self) -> Self: + factory = self.__class__() + factory._url_cfg.update(self._url_cfg) + factory._engine_cfg.update(self._engine_cfg) + factory._maker_cfg.update(self._maker_cfg) + factory._transaction_ctx_cfg.update(self._transaction_ctx_cfg) + factory._facade_cfg.update(self._facade_cfg) + return factory + def _create_connection(self, mode): if not self._started: self._start() if mode is _WRITER: return self._writer_engine.connect() - elif mode is _ASYNC_READER or \ - (mode is _READER and not self.synchronous_reader): + elif mode is _ASYNC_READER or ( + mode is _READER and not self.synchronous_reader + ): return self._reader_engine.connect() else: return self._writer_engine.connect() @@ -405,24 +461,16 @@ class _TransactionFactory: # don't pass 'bind' if bind is None; the sessionmaker # already has a bind to the engine. if bind: - kw['bind'] = bind + kw["bind"] = bind if mode is _WRITER: return self._writer_maker(**kw) - elif mode is _ASYNC_READER or \ - (mode is _READER and not self.synchronous_reader): + elif mode is _ASYNC_READER or ( + mode is _READER and not self.synchronous_reader + ): return self._reader_maker(**kw) else: return self._writer_maker(**kw) - def _create_factory_copy(self): - factory = _TransactionFactory() - factory._url_cfg.update(self._url_cfg) - factory._engine_cfg.update(self._engine_cfg) - factory._maker_cfg.update(self._maker_cfg) - factory._transaction_ctx_cfg.update(self._transaction_ctx_cfg) - factory._facade_cfg.update(self._facade_cfg) - return factory - def _args_for_conf(self, default_cfg, conf): if conf is None: return { @@ -447,16 +495,6 @@ class _TransactionFactory: maker_args = self._args_for_conf(self._maker_cfg, conf) return maker_args - def dispose_pool(self): - """Call engine.pool.dispose() on underlying Engine objects.""" - with self._start_lock: - if not self._started: - return - - self._writer_engine.pool.dispose() - if self._reader_engine is not self._writer_engine: - self._reader_engine.pool.dispose() - @property def is_started(self): """True if this :class:`._TransactionFactory` is already started.""" @@ -478,46 +516,138 @@ class _TransactionFactory: # the cfg.CONF to maintain exact compatibility with # the EngineFacade design. This can be changed if needed. if conf is not None: - conf.register_opts(options.database_opts, 'database') + conf.register_opts(options.database_opts, "database") url_args = self._url_args_for_conf(conf) if connection: - url_args['connection'] = connection + url_args["connection"] = connection if slave_connection: - url_args['slave_connection'] = slave_connection + url_args["slave_connection"] = slave_connection engine_args = self._engine_args_for_conf(conf) maker_args = self._maker_args_for_conf(conf) - self._writer_engine, self._writer_maker = \ + self._writer_engine, self._writer_maker = ( self._setup_for_connection( - url_args['connection'], - engine_args, maker_args) + url_args["connection"], engine_args, maker_args + ) + ) - if url_args.get('slave_connection'): - self._reader_engine, self._reader_maker = \ + if url_args.get("slave_connection"): + self._reader_engine, self._reader_maker = ( self._setup_for_connection( - url_args['slave_connection'], - engine_args, maker_args) + url_args["slave_connection"], engine_args, maker_args + ) + ) else: - self._reader_engine, self._reader_maker = \ - self._writer_engine, self._writer_maker + self._reader_engine, self._reader_maker = ( + self._writer_engine, + self._writer_maker, + ) - self.synchronous_reader = self._facade_cfg['synchronous_reader'] + self.synchronous_reader = self._facade_cfg["synchronous_reader"] # set up _started last, so that in case of exceptions # we try the whole thing again and report errors # correctly self._started = True + +class _TransactionFactory(_AbstractTransactionFactory): + """A factory for :class:`._TransactionContext` objects. + + By default, there is just one of these, set up + based on CONF, however instance-level :class:`._TransactionFactory` + objects can be made, as is the case with the + :class:`._TestTransactionFactory` subclass used by the oslo.db test suite. + """ + + _start_lock: threading.Lock + _writer_engine: Engine + _reader_engine: Engine + _writer_maker: sessionmaker + _reader_maker: sessionmaker + + def _make_lock(self) -> threading.Lock: + return threading.Lock() + + if TYPE_CHECKING: + + def get_writer_engine(self) -> Engine: + """Return the writer engine for this factory. + + Implies start. + """ + ... + + def get_reader_engine(self) -> Engine: + """Return the reader engine for this factory. + + Implies start. + """ + ... + + def get_writer_maker(self) -> sessionmaker: + """Return the writer sessionmaker for this factory. + + Implies start. + """ + ... + + def get_reader_maker(self) -> sessionmaker: + """Return the reader sessionmaker for this factory. + + Implies start. + """ + ... + + def _create_factory_copy(self): + factory = _TransactionFactory() + factory._url_cfg.update(self._url_cfg) + factory._engine_cfg.update(self._engine_cfg) + factory._maker_cfg.update(self._maker_cfg) + factory._transaction_ctx_cfg.update(self._transaction_ctx_cfg) + factory._facade_cfg.update(self._facade_cfg) + return factory + + def get_legacy_facade(self): + """Return a :class:`.LegacyEngineFacade` for this factory. + + This facade will make use of the same engine and sessionmaker + as this factory, however will not share the same transaction context; + the legacy facade continues to work the old way of returning + a new Session each time get_session() is called. + """ + if not self._legacy_facade: + self._legacy_facade = LegacyEngineFacade(None, _factory=self) + if not self._started: + self._start() + + return self._legacy_facade + + def dispose_pool(self): + """Call engine.pool.dispose() on underlying Engine objects.""" + with self._start_lock: + if not self._started: + return + + self._writer_engine.pool.dispose() + if self._reader_engine is not self._writer_engine: + self._reader_engine.pool.dispose() + def _setup_for_connection( - self, sql_connection, engine_kwargs, maker_kwargs, + self, + sql_connection, + engine_kwargs, + maker_kwargs, ): if sql_connection is None: raise exception.CantStartEngineError( - "No sql_connection parameter is established") + "No sql_connection parameter is established" + ) engine = engines.create_engine( - sql_connection=sql_connection, **engine_kwargs) - for hook in self._facade_cfg['on_engine_create']: + sql_connection=sql_connection, **engine_kwargs + ) + for hook in self._facade_cfg["on_engine_create"]: hook(engine) sessionmaker = orm.get_maker(engine=engine, **maker_kwargs) return engine, sessionmaker @@ -541,8 +671,9 @@ class _TestTransactionFactory(_TransactionFactory): """ @debtcollector.removals.removed_kwarg( - 'synchronous_reader', - 'argument value is propagated from the parent _TransactionFactory') + "synchronous_reader", + "argument value is propagated from the parent _TransactionFactory", + ) def __init__(self, engine, maker, apply_global, from_factory=None, **kw): # NOTE(zzzeek): **kw needed for backwards compability self._reader_engine = self._writer_engine = engine @@ -556,7 +687,7 @@ class _TestTransactionFactory(_TransactionFactory): self._facade_cfg = from_factory._facade_cfg self._transaction_ctx_cfg = from_factory._transaction_ctx_cfg - self.synchronous_reader = self._facade_cfg['synchronous_reader'] + self.synchronous_reader = self._facade_cfg["synchronous_reader"] if apply_global: self.existing_factory = _context_manager._factory @@ -566,7 +697,7 @@ class _TestTransactionFactory(_TransactionFactory): _context_manager._root_factory = self.existing_factory -class _TransactionContext(object): +class _AbstractTransactionContext: """Represent a single database transaction in progress.""" def __init__(self, factory, global_factory=None): @@ -586,8 +717,75 @@ class _TransactionContext(object): self.connection = None self.transaction = None kw = self.factory._transaction_ctx_cfg - self.rollback_reader_sessions = kw['rollback_reader_sessions'] - self.flush_on_subtransaction = kw['flush_on_subtransaction'] + self.rollback_reader_sessions = kw["rollback_reader_sessions"] + self.flush_on_subtransaction = kw["flush_on_subtransaction"] + + @contextlib.contextmanager + def _add_context(self, connection, context): + restore_context = connection.info.get("using_context") + connection.info["using_context"] = context + yield connection + connection.info["using_context"] = restore_context + + def _writer(self): + if self.mode is None: + self.mode = _WRITER + elif self.mode is _READER: + raise TypeError( + "Can't upgrade a READER transaction " + "to a WRITER mid-transaction" + ) + elif self.mode is _ASYNC_READER: + raise TypeError( + "Can't upgrade an ASYNC_READER transaction " + "to a WRITER mid-transaction" + ) + + def _reader(self, allow_async=False): + if self.mode is None: + self.mode = _READER + elif self.mode is _ASYNC_READER and not allow_async: + raise TypeError( + "Can't upgrade an ASYNC_READER transaction " + "to a READER mid-transaction" + ) + + def _async_reader(self): + if self.mode is None: + self.mode = _ASYNC_READER + + +class _TransactionContext(_AbstractTransactionContext): + def _end_session_transaction(self, session): + if self.mode is _WRITER: + session.commit() + elif self.rollback_reader_sessions: + session.rollback() + # In the absence of calling session.rollback(), + # the next call is session.close(). This releases all + # objects from the session into the detached state, and + # releases the connection as well; the connection when returned + # to the pool is either rolled back in any case, or closed fully. + + def _end_connection_transaction(self, transaction): + if self.mode is _WRITER: + transaction.commit() + else: + transaction.rollback() + + def _produce_block( + self, mode, connection, savepoint, allow_async=False, context=None + ): + if mode is _WRITER: + self._writer() + elif mode is _ASYNC_READER: + self._async_reader() + else: + self._reader(allow_async) + if connection: + return self._connection(savepoint, context=context) + else: + return self._session(savepoint, context=context) @contextlib.contextmanager def _connection(self, savepoint=False, context=None): @@ -597,8 +795,9 @@ class _TransactionContext(object): # use existing session, which is outer to us self.connection = self.session.connection() if savepoint: - with self.connection.begin_nested(), \ - self._add_context(self.connection, context): + with self.connection.begin_nested(), self._add_context( + self.connection, context + ): yield self.connection else: with self._add_context(self.connection, context): @@ -606,7 +805,8 @@ class _TransactionContext(object): else: # is outermost self.connection = self.factory._create_connection( - mode=self.mode) + mode=self.mode + ) self.transaction = self.connection.begin() try: with self._add_context(self.connection, context): @@ -627,8 +827,9 @@ class _TransactionContext(object): else: # use existing connection, which is outer to us if savepoint: - with self.connection.begin_nested(), \ - self._add_context(self.connection, context): + with self.connection.begin_nested(), self._add_context( + self.connection, context + ): yield self.connection else: with self._add_context(self.connection, context): @@ -638,7 +839,8 @@ class _TransactionContext(object): def _session(self, savepoint=False, context=None): if self.session is None: self.session = self.factory._create_session( - bind=self.connection, mode=self.mode) + bind=self.connection, mode=self.mode + ) try: self.session.begin() with self._add_context(self.session, context): @@ -662,67 +864,6 @@ class _TransactionContext(object): if self.flush_on_subtransaction: self.session.flush() - @contextlib.contextmanager - def _add_context(self, connection, context): - restore_context = connection.info.get('using_context') - connection.info['using_context'] = context - yield connection - connection.info['using_context'] = restore_context - - def _end_session_transaction(self, session): - if self.mode is _WRITER: - session.commit() - elif self.rollback_reader_sessions: - session.rollback() - # In the absence of calling session.rollback(), - # the next call is session.close(). This releases all - # objects from the session into the detached state, and - # releases the connection as well; the connection when returned - # to the pool is either rolled back in any case, or closed fully. - - def _end_connection_transaction(self, transaction): - if self.mode is _WRITER: - transaction.commit() - else: - transaction.rollback() - - def _produce_block(self, mode, connection, savepoint, allow_async=False, - context=None): - if mode is _WRITER: - self._writer() - elif mode is _ASYNC_READER: - self._async_reader() - else: - self._reader(allow_async) - if connection: - return self._connection(savepoint, context=context) - else: - return self._session(savepoint, context=context) - - def _writer(self): - if self.mode is None: - self.mode = _WRITER - elif self.mode is _READER: - raise TypeError( - "Can't upgrade a READER transaction " - "to a WRITER mid-transaction") - elif self.mode is _ASYNC_READER: - raise TypeError( - "Can't upgrade an ASYNC_READER transaction " - "to a WRITER mid-transaction") - - def _reader(self, allow_async=False): - if self.mode is None: - self.mode = _READER - elif self.mode is _ASYNC_READER and not allow_async: - raise TypeError( - "Can't upgrade an ASYNC_READER transaction " - "to a READER mid-transaction") - - def _async_reader(self): - if self.mode is None: - self.mode = _ASYNC_READER - class _TransactionContextTLocal(threading.local): def __deepcopy__(self, memo): @@ -732,7 +873,7 @@ class _TransactionContextTLocal(threading.local): return _TransactionContextTLocal, () -class _TransactionContextManager(object): +class _AbstractTransactionContextManager(Generic[_AllEngineSessionT]): """Provide context-management and decorator patterns for transactions. This object integrates user-defined "context" objects with the @@ -742,18 +883,20 @@ class _TransactionContextManager(object): """ def __init__( - self, root=None, - mode=None, - independent=False, - savepoint=False, - connection=False, - replace_global_factory=None, - _is_global_manager=False, - allow_async=False): + self, + root=None, + mode=None, + independent=False, + savepoint=False, + connection=False, + replace_global_factory=None, + _is_global_manager=False, + allow_async=False, + ): if root is None: self._root = self - self._root_factory = _TransactionFactory() + self._root_factory = self._create_root_factory() else: self._root = root @@ -764,10 +907,17 @@ class _TransactionContextManager(object): self._savepoint = savepoint if self._savepoint and self._independent: raise TypeError( - "setting savepoint and independent makes no sense.") + "setting savepoint and independent makes no sense." + ) self._connection = connection self._allow_async = allow_async + def _create_root_factory(self): + return _TransactionFactory() + + def _create_transaction_context(self, use_factory, global_factory): + return _TransactionContext(use_factory, global_factory=global_factory) + @property def _factory(self): """The :class:`._TransactionFactory` associated with this context.""" @@ -790,7 +940,7 @@ class _TransactionContextManager(object): def append_on_engine_create(self, fn): """Append a listener function to _facade_cfg["on_engine_create"]""" - self._factory._facade_cfg['on_engine_create'].append(fn) + self._factory._facade_cfg["on_engine_create"].append(fn) def get_legacy_facade(self): """Return a :class:`.LegacyEngineFacade` for factory from this context. @@ -852,7 +1002,7 @@ class _TransactionContextManager(object): new._root = new new._root_factory = self._root_factory._create_factory_copy() if new._factory._started: - raise AssertionError('TransactionFactory is already started') + raise AssertionError("TransactionFactory is already started") return new def patch_factory(self, factory_or_manager): @@ -877,9 +1027,10 @@ class _TransactionContextManager(object): else: raise ValueError( "_TransactionContextManager or " - "_TransactionFactory expected.") + "_TransactionFactory expected." + ) if self._root is not self: - raise AssertionError('patch_factory only works for root factory.') + raise AssertionError("patch_factory only works for root factory.") existing_factory = self._root_factory self._root_factory = factory @@ -909,29 +1060,27 @@ class _TransactionContextManager(object): maker = orm.get_maker(engine=engine, **maker_kwargs) factory = _TestTransactionFactory( - engine, maker, - apply_global=False, - from_factory=existing_factory + engine, maker, apply_global=False, from_factory=existing_factory ) return self.patch_factory(factory) @property - def replace(self): + def replace(self) -> Self: """Modifier to replace the global transaction factory with this one.""" return self._clone(replace_global_factory=self._factory) @property - def writer(self): + def writer(self) -> Self: """Modifier to set the transaction to WRITER.""" return self._clone(mode=_WRITER) @property - def reader(self): + def reader(self) -> Self: """Modifier to set the transaction to READER.""" return self._clone(mode=_READER) @property - def allow_async(self): + def allow_async(self) -> Self: """Modifier to allow async operations Allows async operations if asynchronous session is already @@ -952,36 +1101,73 @@ class _TransactionContextManager(object): return self._clone(allow_async=True) @property - def independent(self): + def independent(self) -> Self: """Modifier to start a transaction independent from any enclosing.""" return self._clone(independent=True) @property - def savepoint(self): + def savepoint(self) -> Self: """Modifier to start a SAVEPOINT if a transaction already exists.""" return self._clone(savepoint=True) @property - def connection(self): - """Modifier to return a core Connection object instead of Session.""" - return self._clone(connection=True) - - @property - def async_(self): + def async_(self) -> Self: """Modifier to set a READER operation to ASYNC_READER.""" if self._mode is _WRITER: raise TypeError("Setting async on a WRITER makes no sense") return self._clone(mode=_ASYNC_READER) - def using(self, context): - """Provide a context manager block that will use the given context.""" - return self._transaction_scope(context) + def _clone(self, **kw) -> Any: + default_kw = { + "independent": self._independent, + "mode": self._mode, + "connection": self._connection, + } + default_kw.update(kw) + return self.__class__(root=self._root, **default_kw) + def _transaction_scope_impl(self, transaction_contexts_concurrent) -> Any: + + new_transaction = self._independent + + current = restore = getattr( + transaction_contexts_concurrent, "current", None + ) + + use_factory = self._factory + global_factory = None + + if self._replace_global_factory: + use_factory = global_factory = self._replace_global_factory + elif current is not None and current.global_factory: + global_factory = current.global_factory + + if self._root._is_global_manager: + use_factory = global_factory + + if current is not None and ( + new_transaction or current.factory is not use_factory + ): + current = None + + if current is None: + current = transaction_contexts_concurrent.current = ( + self._create_transaction_context( + use_factory, global_factory=global_factory + ) + ) + + return current, restore + + +class _TransactionContextManager( + _AbstractTransactionContextManager[_SyncEngineSessionT] +): def __call__(self, fn): """Decorate a function.""" argspec = inspect.getfullargspec(fn) - if argspec.args[0] == 'self' or argspec.args[0] == 'cls': + if argspec.args[0] == "self" or argspec.args[0] == "cls": context_index = 1 else: context_index = 0 @@ -998,43 +1184,15 @@ class _TransactionContextManager(object): return wrapper - def _clone(self, **kw): - default_kw = { - "independent": self._independent, - "mode": self._mode, - "connection": self._connection - } - default_kw.update(kw) - return _TransactionContextManager(root=self._root, **default_kw) - @contextlib.contextmanager - def _transaction_scope(self, context): - new_transaction = self._independent - transaction_contexts_by_thread = \ - _transaction_contexts_by_thread(context) + def _transaction_scope(self, context) -> Any: + transaction_contexts_by_thread = _transaction_contexts_by_thread( + context + ) - current = restore = getattr( - transaction_contexts_by_thread, "current", None) - - use_factory = self._factory - global_factory = None - - if self._replace_global_factory: - use_factory = global_factory = self._replace_global_factory - elif current is not None and current.global_factory: - global_factory = current.global_factory - - if self._root._is_global_manager: - use_factory = global_factory - - if current is not None and ( - new_transaction or current.factory is not use_factory - ): - current = None - - if current is None: - current = transaction_contexts_by_thread.current = \ - _TransactionContext(use_factory, global_factory=global_factory) + current, restore = self._transaction_scope_impl( + transaction_contexts_by_thread + ) try: if self._mode is not None: @@ -1043,7 +1201,8 @@ class _TransactionContextManager(object): connection=self._connection, savepoint=self._savepoint, allow_async=self._allow_async, - context=context) as resource: + context=context, + ) as resource: yield resource else: yield @@ -1053,18 +1212,31 @@ class _TransactionContextManager(object): elif current is not restore: transaction_contexts_by_thread.current = restore + @property + def connection(self) -> _TransactionContextManager[Connection]: + """Modifier to return a core Connection object instead of Session.""" + return self._clone(connection=True) -def _context_descriptor(attr=None): + def using(self, context) -> ContextManager[_SyncEngineSessionT]: + """Provide a context manager block that will use the given context.""" + return self._transaction_scope(context) + + +def _context_descriptor( + attr=None, transaction_ctx_getter=operator.attrgetter("transaction_ctx") +): getter = operator.attrgetter(attr) def _property_for_context(context): try: - transaction_context = context.transaction_ctx + transaction_context = transaction_ctx_getter(context) except exception.NoEngineContextEstablished: raise exception.NoEngineContextEstablished( "No TransactionContext is established for " "this %s object within the current thread; " - "the %r attribute is unavailable." + "the %r attribute is unavailable. " + "Ensure this Context class participates in the " + "transaction_context_provider() class decorator." % (context, attr) ) else: @@ -1075,6 +1247,7 @@ def _context_descriptor(attr=None): "it has not been established for this context." % attr ) return result + return property(_property_for_context) @@ -1086,16 +1259,19 @@ def _transaction_ctx_for_context(context): raise exception.NoEngineContextEstablished( "No TransactionContext is established for " "this %s object within the current thread. " - % context + "Ensure this Context class participates in the " + "transaction_context_provider() class decorator." % context ) def _transaction_contexts_by_thread(context): transaction_contexts_by_thread = getattr( - context, '_enginefacade_context', None) + context, "_enginefacade_context", None + ) if transaction_contexts_by_thread is None: - transaction_contexts_by_thread = \ - context._enginefacade_context = _TransactionContextTLocal() + transaction_contexts_by_thread = context._enginefacade_context = ( + _TransactionContextTLocal() + ) return transaction_contexts_by_thread @@ -1103,26 +1279,23 @@ def _transaction_contexts_by_thread(context): def transaction_context_provider(klass): """Decorate a class with ``session`` and ``connection`` attributes.""" - setattr( - klass, - 'transaction_ctx', - property(_transaction_ctx_for_context)) + setattr(klass, "transaction_ctx", property(_transaction_ctx_for_context)) # Graft transaction context attributes as context properties - for attr in ('session', 'connection', 'transaction'): + for attr in ("session", "connection", "transaction"): setattr(klass, attr, _context_descriptor(attr)) return klass -_context_manager = _TransactionContextManager(_is_global_manager=True) +_context_manager: _TransactionContextManager[Session] = ( + _TransactionContextManager(_is_global_manager=True) +) """default context manager.""" def transaction_context(): - """Construct a local transaction context. - - """ + """Construct a local transaction context.""" return _TransactionContextManager() @@ -1237,14 +1410,23 @@ class LegacyEngineFacade(object): True) """ - def __init__(self, sql_connection, slave_connection=None, - sqlite_fk=False, expire_on_commit=False, _conf=None, - _factory=None, **kwargs): + + def __init__( + self, + sql_connection, + slave_connection=None, + sqlite_fk=False, + expire_on_commit=False, + _conf=None, + _factory=None, + **kwargs, + ): warnings.warn( "EngineFacade is deprecated; please use " "oslo_db.sqlalchemy.enginefacade", warning.OsloDBDeprecationWarning, - stacklevel=2) + stacklevel=2, + ) if _factory: self._factory = _factory @@ -1254,13 +1436,15 @@ class LegacyEngineFacade(object): self._factory.configure( sqlite_fk=sqlite_fk, expire_on_commit=expire_on_commit, - **kwargs + **kwargs, ) # make sure passed-in urls are favored over that # of config self._factory._start( - _conf, connection=sql_connection, - slave_connection=slave_connection) + _conf, + connection=sql_connection, + slave_connection=slave_connection, + ) def _check_factory_started(self): if not self._factory._started: @@ -1317,8 +1501,7 @@ class LegacyEngineFacade(object): return self._factory._writer_maker @classmethod - def from_config(cls, conf, - sqlite_fk=False, expire_on_commit=False): + def from_config(cls, conf, sqlite_fk=False, expire_on_commit=False): """Initialize EngineFacade using oslo.config config instance options. :param conf: oslo.config config instance @@ -1335,4 +1518,6 @@ class LegacyEngineFacade(object): return cls( None, sqlite_fk=sqlite_fk, - expire_on_commit=expire_on_commit, _conf=conf) + expire_on_commit=expire_on_commit, + _conf=conf, + ) diff --git a/oslo_db/sqlalchemy/engines.py b/oslo_db/sqlalchemy/engines.py index dc211fc1..65a8f9c9 100644 --- a/oslo_db/sqlalchemy/engines.py +++ b/oslo_db/sqlalchemy/engines.py @@ -87,17 +87,18 @@ def _connect_ping_listener(connection, branch): # new connections assuming they are good now. # run the select again to re-validate the Connection. LOG.exception( - 'Database connection was found disconnected; reconnecting') + "Database connection was found disconnected; reconnecting" + ) # TODO(ralonsoh): drop this attr check once SQLAlchemy minimum version # is 2.0. - if hasattr(connection, 'rollback'): + if hasattr(connection, "rollback"): connection.rollback() connection.scalar(select(1)) finally: connection.should_close_with_result = save_should_close_with_result # TODO(ralonsoh): drop this attr check once SQLAlchemy minimum version # is 2.0. - if hasattr(connection, 'rollback'): + if hasattr(connection, "rollback"): connection.rollback() @@ -106,7 +107,8 @@ def _connect_ping_listener(connection, branch): # and wrap out a parameter that is deprecated if compat.sqla_2: _connect_ping_listener = functools.partial( - _connect_ping_listener, branch=False) + _connect_ping_listener, branch=False + ) def _setup_logging(connection_debug=0): @@ -119,7 +121,7 @@ def _setup_logging(connection_debug=0): 100=Processed only messages with DEBUG level """ if connection_debug >= 0: - logger = logging.getLogger('sqlalchemy.engine') + logger = logging.getLogger("sqlalchemy.engine") if connection_debug == 100: logger.setLevel(logging.DEBUG) elif connection_debug >= 50: @@ -136,30 +138,78 @@ def _vet_url(url): "and will make use of a default driver. " "A full dbname+drivername:// protocol is recommended. " "For MySQL, it is strongly recommended that mysql+pymysql:// " - "be specified for maximum service compatibility", url + "be specified for maximum service compatibility", + url, ) else: LOG.warning( "URL %r does not contain a '+drivername' portion, " "and will make use of a default driver. " - "A full dbname+drivername:// protocol is recommended.", url + "A full dbname+drivername:// protocol is recommended.", + url, ) +def _create_engine(url, **engine_args): + engine = sqlalchemy.create_engine(url, **engine_args) + return engine, engine + + +def _test_connection(engine, max_retries, retry_interval, _close=True): + if max_retries == -1: + attempts = itertools.count() + else: + attempts = range(max_retries) + # See: http://legacy.python.org/dev/peps/pep-3110/#semantic-changes for + # why we are not using 'de' directly (it can be removed from the local + # scope). + de_ref = conn = None + for attempt in attempts: + try: + conn = engine.connect() + except exception.DBConnectionError as de: + msg = "SQL connection failed. %s attempts left." + LOG.warning(msg, max_retries - attempt) + time.sleep(retry_interval) + de_ref = de + else: + if _close: + conn.close() + break + else: + if de_ref is not None: + raise de_ref + + return conn + + @debtcollector.renames.renamed_kwarg( - 'idle_timeout', - 'connection_recycle_time', + "idle_timeout", + "connection_recycle_time", replace=True, ) -def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, - mysql_wsrep_sync_wait=None, - connection_recycle_time=3600, - connection_debug=0, max_pool_size=None, max_overflow=None, - pool_timeout=None, sqlite_synchronous=True, - connection_trace=False, max_retries=10, retry_interval=10, - thread_checkin=True, logging_name=None, - json_serializer=None, - json_deserializer=None, connection_parameters=None): +def create_engine( + sql_connection, + sqlite_fk=False, + mysql_sql_mode=None, + mysql_wsrep_sync_wait=None, + connection_recycle_time=3600, + connection_debug=0, + max_pool_size=None, + max_overflow=None, + pool_timeout=None, + sqlite_synchronous=True, + connection_trace=False, + max_retries=10, + retry_interval=10, + thread_checkin=True, + logging_name=None, + json_serializer=None, + json_deserializer=None, + connection_parameters=None, + _engine_target=_create_engine, + _test_connection=_test_connection, +): """Return a new SQLAlchemy engine.""" url = utils.make_url(sql_connection) @@ -172,56 +222,58 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, _native_pre_ping = compat.native_pre_ping_event_support engine_args = { - 'pool_recycle': connection_recycle_time, - 'pool_pre_ping': _native_pre_ping, - 'connect_args': {}, - 'logging_name': logging_name + "pool_recycle": connection_recycle_time, + "pool_pre_ping": _native_pre_ping, + "connect_args": {}, + "logging_name": logging_name, } _setup_logging(connection_debug) _init_connection_args( - url, engine_args, + url, + engine_args, dict( max_pool_size=max_pool_size, max_overflow=max_overflow, pool_timeout=pool_timeout, json_serializer=json_serializer, json_deserializer=json_deserializer, - ) + ), ) - engine = sqlalchemy.create_engine(url, **engine_args) + engine, engine_event_target = _engine_target(url, **engine_args) _init_events( - engine, + engine_event_target, mysql_sql_mode=mysql_sql_mode, mysql_wsrep_sync_wait=mysql_wsrep_sync_wait, sqlite_synchronous=sqlite_synchronous, sqlite_fk=sqlite_fk, thread_checkin=thread_checkin, - connection_trace=connection_trace + connection_trace=connection_trace, ) # register alternate exception handler - exc_filters.register_engine(engine) + exc_filters.register_engine(engine_event_target) if not _native_pre_ping: # register engine connect handler. - event.listen(engine, "engine_connect", _connect_ping_listener) + event.listen( + engine_event_target, "engine_connect", _connect_ping_listener + ) # initial connect + test # NOTE(viktors): the current implementation of _test_connection() # does nothing, if max_retries == 0, so we can skip it if max_retries: - test_conn = _test_connection(engine, max_retries, retry_interval) - test_conn.close() + _test_connection(engine_event_target, max_retries, retry_interval) return engine -@utils.dispatch_for_dialect('*', multiple=True) +@utils.dispatch_for_dialect("*", multiple=True) def _init_connection_args(url, engine_args, kw): # (zzzeek) kw is passed by reference rather than as **kw so that the @@ -236,11 +288,11 @@ def _init_connection_args(url, engine_args, kw): pool_class = url.get_dialect().get_pool_class(url) if issubclass(pool_class, pool.QueuePool): if max_pool_size is not None: - engine_args['pool_size'] = max_pool_size + engine_args["pool_size"] = max_pool_size if max_overflow is not None: - engine_args['max_overflow'] = max_overflow + engine_args["max_overflow"] = max_overflow if pool_timeout is not None: - engine_args['pool_timeout'] = pool_timeout + engine_args["pool_timeout"] = pool_timeout @_init_connection_args.dispatch_for("sqlite") @@ -250,7 +302,7 @@ def _init_connection_args(url, engine_args, kw): # singletonthreadpool is used for :memory: connections; # replace it with StaticPool. engine_args["poolclass"] = pool.StaticPool - engine_args['connect_args']['check_same_thread'] = False + engine_args["connect_args"]["check_same_thread"] = False elif issubclass(pool_class, pool.QueuePool): # SQLAlchemy 2.0 uses QueuePool for sqlite file DBs; put NullPool # back to avoid compatibility issues @@ -263,19 +315,19 @@ def _init_connection_args(url, engine_args, kw): @_init_connection_args.dispatch_for("postgresql") def _init_connection_args(url, engine_args, kw): - if 'client_encoding' not in url.query: + if "client_encoding" not in url.query: # Set encoding using engine_args instead of connect_args since # it's supported for PostgreSQL 8.*. More details at: # http://docs.sqlalchemy.org/en/rel_0_9/dialects/postgresql.html - engine_args['client_encoding'] = 'utf8' - engine_args['json_serializer'] = kw.get('json_serializer') - engine_args['json_deserializer'] = kw.get('json_deserializer') + engine_args["client_encoding"] = "utf8" + engine_args["json_serializer"] = kw.get("json_serializer") + engine_args["json_deserializer"] = kw.get("json_deserializer") @_init_connection_args.dispatch_for("mysql") def _init_connection_args(url, engine_args, kw): - if 'charset' not in url.query: - engine_args['connect_args']['charset'] = 'utf8' + if "charset" not in url.query: + engine_args["connect_args"]["charset"] = "utf8" @_init_connection_args.dispatch_for("mysql+mysqlconnector") @@ -283,8 +335,8 @@ def _init_connection_args(url, engine_args, kw): # mysqlconnector engine (<1.0) incorrectly defaults to # raise_on_warnings=True # https://bitbucket.org/zzzeek/sqlalchemy/issue/2515 - if 'raise_on_warnings' not in url.query: - engine_args['connect_args']['raise_on_warnings'] = False + if "raise_on_warnings" not in url.query: + engine_args["connect_args"]["raise_on_warnings"] = False @_init_connection_args.dispatch_for("mysql+mysqldb") @@ -292,11 +344,11 @@ def _init_connection_args(url, engine_args, kw): # Those drivers require use_unicode=0 to avoid performance drop due # to internal usage of Python unicode objects in the driver # http://docs.sqlalchemy.org/en/rel_0_9/dialects/mysql.html - if 'use_unicode' not in url.query: - engine_args['connect_args']['use_unicode'] = 1 + if "use_unicode" not in url.query: + engine_args["connect_args"]["use_unicode"] = 1 -@utils.dispatch_for_dialect('*', multiple=True) +@utils.dispatch_for_dialect("*", multiple=True) def _init_events(engine, thread_checkin=True, connection_trace=False, **kw): """Set up event listeners for all database backends.""" @@ -306,15 +358,17 @@ def _init_events(engine, thread_checkin=True, connection_trace=False, **kw): _add_trace_comments(engine) if thread_checkin: - sqlalchemy.event.listen(engine, 'checkin', _thread_yield) + sqlalchemy.event.listen(engine, "checkin", _thread_yield) @_init_events.dispatch_for("mysql") def _init_events( - engine, mysql_sql_mode=None, mysql_wsrep_sync_wait=None, **kw): + engine, mysql_sql_mode=None, mysql_wsrep_sync_wait=None, **kw +): """Set up event listeners for MySQL.""" if mysql_sql_mode is not None or mysql_wsrep_sync_wait is not None: + @sqlalchemy.event.listens_for(engine, "connect") def _set_session_variables(dbapi_con, connection_rec): cursor = dbapi_con.cursor() @@ -322,8 +376,7 @@ def _init_events( cursor.execute("SET SESSION sql_mode = %s", [mysql_sql_mode]) if mysql_wsrep_sync_wait is not None: cursor.execute( - "SET SESSION wsrep_sync_wait = %s", - [mysql_wsrep_sync_wait] + "SET SESSION wsrep_sync_wait = %s", [mysql_wsrep_sync_wait] ) @sqlalchemy.event.listens_for(engine, "first_connect") @@ -336,16 +389,19 @@ def _init_events( realmode = cursor.fetchone() if realmode is None: - LOG.warning('Unable to detect effective SQL mode') + LOG.warning("Unable to detect effective SQL mode") else: realmode = realmode[1] - LOG.debug('MySQL server mode set to %s', realmode) - if 'TRADITIONAL' not in realmode.upper() and \ - 'STRICT_ALL_TABLES' not in realmode.upper(): + LOG.debug("MySQL server mode set to %s", realmode) + if ( + "TRADITIONAL" not in realmode.upper() and + "STRICT_ALL_TABLES" not in realmode.upper() + ): LOG.warning( "MySQL SQL mode is '%s', " "consider enabling TRADITIONAL or STRICT_ALL_TABLES", - realmode) + realmode, + ) @_init_events.dispatch_for("sqlite") @@ -365,7 +421,7 @@ def _init_events(engine, sqlite_synchronous=True, sqlite_fk=False, **kw): def _sqlite_connect_events(dbapi_con, con_record): # Add REGEXP functionality on SQLite connections - dbapi_con.create_function('regexp', 2, regexp) + dbapi_con.create_function("regexp", 2, regexp) if not sqlite_synchronous: # Switch sqlite connections to non-synchronous mode @@ -380,43 +436,21 @@ def _init_events(engine, sqlite_synchronous=True, sqlite_fk=False, **kw): if sqlite_fk: # Ensures that the foreign key constraints are enforced in SQLite. - dbapi_con.execute('pragma foreign_keys=ON') + dbapi_con.execute("pragma foreign_keys=ON") @sqlalchemy.event.listens_for(engine, "begin") def _sqlite_emit_begin(conn): # emit our own BEGIN, checking for existing # transactional state - if 'in_transaction' not in conn.info: + if "in_transaction" not in conn.info: conn.execute(sqlalchemy.text("BEGIN")) - conn.info['in_transaction'] = True + conn.info["in_transaction"] = True @sqlalchemy.event.listens_for(engine, "rollback") @sqlalchemy.event.listens_for(engine, "commit") def _sqlite_end_transaction(conn): # remove transactional marker - conn.info.pop('in_transaction', None) - - -def _test_connection(engine, max_retries, retry_interval): - if max_retries == -1: - attempts = itertools.count() - else: - attempts = range(max_retries) - # See: http://legacy.python.org/dev/peps/pep-3110/#semantic-changes for - # why we are not using 'de' directly (it can be removed from the local - # scope). - de_ref = None - for attempt in attempts: - try: - return engine.connect() - except exception.DBConnectionError as de: - msg = 'SQL connection failed. %s attempts left.' - LOG.warning(msg, max_retries - attempt) - time.sleep(retry_interval) - de_ref = de - else: - if de_ref is not None: - raise de_ref + conn.info.pop("in_transaction", None) def _add_process_guards(engine): @@ -429,21 +463,22 @@ def _add_process_guards(engine): @sqlalchemy.event.listens_for(engine, "connect") def connect(dbapi_connection, connection_record): - connection_record.info['pid'] = os.getpid() + connection_record.info["pid"] = os.getpid() @sqlalchemy.event.listens_for(engine, "checkout") def checkout(dbapi_connection, connection_record, connection_proxy): pid = os.getpid() - if connection_record.info['pid'] != pid: + if connection_record.info["pid"] != pid: LOG.debug( "Parent process %(orig)s forked (%(newproc)s) with an open " "database connection, " "which is being discarded and recreated.", - {"newproc": pid, "orig": connection_record.info['pid']}) + {"newproc": pid, "orig": connection_record.info["pid"]}, + ) raise exc.DisconnectionError( "Connection record belongs to pid %s, " - "attempting to check out in pid %s" % - (connection_record.info['pid'], pid) + "attempting to check out in pid %s" + % (connection_record.info["pid"], pid) ) @@ -457,20 +492,26 @@ def _add_trace_comments(engine): import os import sys import traceback - target_paths = set([ - os.path.dirname(sys.modules['oslo_db'].__file__), - os.path.dirname(sys.modules['sqlalchemy'].__file__) - ]) + + target_paths = set( + [ + os.path.dirname(sys.modules["oslo_db"].__file__), + os.path.dirname(sys.modules["sqlalchemy"].__file__), + ] + ) try: - skip_paths = set([ - os.path.dirname(sys.modules['oslo_db.tests'].__file__), - ]) + skip_paths = set( + [ + os.path.dirname(sys.modules["oslo_db.tests"].__file__), + ] + ) except KeyError: skip_paths = set() @sqlalchemy.event.listens_for(engine, "before_cursor_execute", retval=True) - def before_cursor_execute(conn, cursor, statement, parameters, context, - executemany): + def before_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): # NOTE(zzzeek) - if different steps per DB dialect are desirable # here, switch out on engine.name for now. @@ -491,12 +532,9 @@ def _add_trace_comments(engine): if our_line: trace = "; ".join( - "File: %s (%s) %s" % ( - line[0], line[1], line[2] - ) + "File: %s (%s) %s" % (line[0], line[1], line[2]) # include three lines of context. - for line in stack[our_line - 3:our_line] - + for line in stack[our_line - 3: our_line] ) statement = "%s -- %s" % (statement, trace) diff --git a/oslo_db/tests/base.py b/oslo_db/tests/base.py index a74ef827..10d2420c 100644 --- a/oslo_db/tests/base.py +++ b/oslo_db/tests/base.py @@ -10,6 +10,8 @@ # License for the specific language governing permissions and limitations # under the License. +import unittest + from oslotest import base from oslo_db.tests import fixtures @@ -23,3 +25,11 @@ class BaseTestCase(base.BaseTestCase): super().setUp() self.warning_fixture = self.useFixture(fixtures.WarningsFixture()) + + +class BaseAsyncioCase(unittest.IsolatedAsyncioTestCase): + def run(self, result=None): + # work around stestr sending a result object that's not a + # unittest.Result + result.addDuration = lambda test, elapsed: None + return super().run(result=result) diff --git a/oslo_db/tests/sqlalchemy/test_asyncio_facade.py b/oslo_db/tests/sqlalchemy/test_asyncio_facade.py new file mode 100644 index 00000000..59a0a2f3 --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_asyncio_facade.py @@ -0,0 +1,67 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_context import context as oslo_context +from oslo_db.sqlalchemy import asyncio_facade +from oslo_db.tests import base as test_base +from sqlalchemy import text + +asyncio_facade.async_transaction_context_provider(oslo_context.RequestContext) + + +class AsyncioFacadeTest(test_base.BaseAsyncioCase): + + def setUp(self): + super().setUp() + asyncio_facade.configure( + connection="sqlite+aiosqlite://", + ) + + def tearDown(self): + asyncio_facade._async_context_manager._root_factory = ( + asyncio_facade._async_context_manager._create_root_factory() + ) + super().tearDown() + + async def test_contextmanager_session(self): + context = oslo_context.RequestContext() + async with asyncio_facade.reader.using(context) as session: + result = await session.execute(text("select 1")) + self.assertEqual(result.all(), [(1,)]) + + async def test_contextmanager_connection(self): + context = oslo_context.RequestContext() + async with asyncio_facade.reader.connection.using( + context + ) as connection: + result = await connection.execute(text("select 1")) + self.assertEqual(result.all(), [(1,)]) + + async def test_callable_session(self): + context = oslo_context.RequestContext() + + @asyncio_facade.reader + async def select_one(context): + result = await context.async_session.execute(text("select 1")) + self.assertEqual(result.all(), [(1,)]) + + await select_one(context) + + async def test_callable_connection(self): + context = oslo_context.RequestContext() + + @asyncio_facade.reader.connection + async def select_one(context): + result = await context.async_connection.execute(text("select 1")) + self.assertEqual(result.all(), [(1,)]) + + await select_one(context) diff --git a/oslo_db/tests/sqlalchemy/test_exc_filters.py b/oslo_db/tests/sqlalchemy/test_exc_filters.py index d5bd5ce1..d062ca32 100644 --- a/oslo_db/tests/sqlalchemy/test_exc_filters.py +++ b/oslo_db/tests/sqlalchemy/test_exc_filters.py @@ -1552,7 +1552,8 @@ class TestDBConnectRetry(TestsExceptionFilter): with self._dbapi_fixture(dialect_name): with mock.patch.object(engine.dialect, "connect", cant_connect): - return engines._test_connection(engine, retries, .01) + return engines._test_connection( + engine, retries, .01, _close=False) def test_connect_no_retries(self): conn = self._run_test( diff --git a/test-requirements.txt b/test-requirements.txt index 6aae3ef6..5df67652 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -9,6 +9,7 @@ stestr>=2.0.0 # Apache-2.0 testtools>=2.2.0 # MIT bandit>=1.7.0,<1.8.0 # Apache-2.0 pifpaf>=0.10.0 # Apache-2.0 +aiosqlite>=0.20.0 # MIT License PyMySQL>=0.7.6 # MIT License psycopg2>=2.8.0 # LGPL/ZPL pre-commit>=2.6.0 # MIT