diff --git a/oslo_db/sqlalchemy/enginefacade.py b/oslo_db/sqlalchemy/enginefacade.py index 2b3fa73..d80ff05 100644 --- a/oslo_db/sqlalchemy/enginefacade.py +++ b/oslo_db/sqlalchemy/enginefacade.py @@ -13,6 +13,7 @@ import contextlib import functools +import inspect import operator import threading import warnings @@ -700,10 +701,15 @@ class _TransactionContextManager(object): def __call__(self, fn): """Decorate a function.""" + argspec = inspect.getargspec(fn) + if argspec.args[0] == 'self' or argspec.args[0] == 'cls': + context_index = 1 + else: + context_index = 0 @functools.wraps(fn) def wrapper(*args, **kwargs): - context = args[0] + context = args[context_index] with self._transaction_scope(context): return fn(*args, **kwargs) diff --git a/oslo_db/tests/sqlalchemy/test_enginefacade.py b/oslo_db/tests/sqlalchemy/test_enginefacade.py index 97229f0..7f77c28 100644 --- a/oslo_db/tests/sqlalchemy/test_enginefacade.py +++ b/oslo_db/tests/sqlalchemy/test_enginefacade.py @@ -1004,6 +1004,19 @@ class MockFacadeTest(oslo_test_base.BaseTestCase): getattr, context, 'transaction_ctx' ) + def test_context_found_for_bound_method(self): + context = oslo_context.RequestContext() + + @enginefacade.reader + def go(self, context): + context.session.execute("test") + go(self, context) + + with self._assert_engines() as engines: + with self._assert_makers(engines) as makers: + with self._assert_reader_session(makers) as session: + session.execute("test") + class SynchronousReaderWSlaveMockFacadeTest(MockFacadeTest): synchronous_reader = True