diff --git a/mistral/services/scheduler.py b/mistral/services/scheduler.py index 25e21ee1..ec60742f 100644 --- a/mistral/services/scheduler.py +++ b/mistral/services/scheduler.py @@ -131,7 +131,7 @@ class CallScheduler(periodic_task.PeriodicTasks): for call in calls_to_make: LOG.debug('Processing next delayed call: %s', call) - context.set_ctx(context.MistralContext(call.auth_context)) + target_auth_context = copy.deepcopy(call.auth_context) if call.factory_method_path: factory = importutils.import_class( @@ -157,19 +157,30 @@ class CallScheduler(periodic_task.PeriodicTasks): method_args[arg_name] = deserialized - delayed_calls.append((target_method, method_args)) + delayed_calls.append( + (target_auth_context, target_method, method_args) + ) + + for (target_auth_context, target_method, method_args) in delayed_calls: - for (target_method, method_args) in delayed_calls: # Transaction is needed here because some of the # target_method can use the DB with db_api.transaction(): try: + # Set the correct context for the method. + context.set_ctx( + context.MistralContext(target_auth_context) + ) + # Call the method. target_method(**method_args) except Exception as e: LOG.error( "Delayed call failed [exception=%s]", e ) + finally: + # Remove context. + context.set_ctx(None) with db_api.transaction(): for call in calls_to_make: diff --git a/mistral/tests/unit/services/test_scheduler.py b/mistral/tests/unit/services/test_scheduler.py index 0c8b9c16..af33cd16 100644 --- a/mistral/tests/unit/services/test_scheduler.py +++ b/mistral/tests/unit/services/test_scheduler.py @@ -17,6 +17,7 @@ import datetime import eventlet import mock +from mistral import context as auth_context from mistral.db.v2 import api as db_api from mistral import exceptions as exc from mistral.services import scheduler @@ -24,12 +25,21 @@ from mistral.tests.unit import base from mistral.workflow import utils as wf_utils -FACTORY_METHOD_PATH = ('mistral.tests.unit.services.test_scheduler.' - 'factory_method') -TARGET_METHOD_PATH = ('mistral.tests.unit.services.test_scheduler.' - 'target_method') -FAILED_TO_SEND_TARGET_PATH = ('mistral.tests.unit.services.test_scheduler.' - 'failed_to_send') +FACTORY_METHOD_PATH = ( + 'mistral.tests.unit.services.test_scheduler.factory_method' +) +TARGET_METHOD_PATH = ( + 'mistral.tests.unit.services.test_scheduler.target_method' +) +CHECK_CONTEXT_METHOD_PATH = ( + 'mistral.tests.unit.services.test_scheduler.target_check_context_method' +) +COMPARE_CONTEXT_METHOD_PATH = ( + 'mistral.tests.unit.services.test_scheduler.compare_context_values' +) +FAILED_TO_SEND_TARGET_PATH = ( + 'mistral.tests.unit.services.test_scheduler.failed_to_send' +) DELAY = 1.5 WAIT = DELAY * 3 @@ -47,6 +57,15 @@ def target_method(): pass +def compare_context_values(expected, actual): + pass + + +def target_check_context_method(expected_project_id): + actual_project_id = auth_context.ctx()._BaseContext__values['project_id'] + compare_context_values(expected_project_id, actual_project_id) + + def failed_to_send(): raise exc.EngineException("Test") @@ -128,6 +147,50 @@ class SchedulerServiceTest(base.DbTestCase): self.assertEqual(0, len(calls)) + @mock.patch(COMPARE_CONTEXT_METHOD_PATH) + def test_scheduler_call_target_method_with_correct_auth(self, method): + default_context = base.get_context(default=True) + auth_context.set_ctx(default_context) + default_project_id = ( + default_context._BaseContext__values['project_id'] + ) + method_args1 = {'expected_project_id': default_project_id} + + scheduler.schedule_call( + None, + CHECK_CONTEXT_METHOD_PATH, + DELAY, + **method_args1 + ) + + second_context = base.get_context(default=False) + auth_context.set_ctx(second_context) + second_project_id = ( + second_context._BaseContext__values['project_id'] + ) + method_args2 = {'expected_project_id': second_project_id} + + scheduler.schedule_call( + None, + CHECK_CONTEXT_METHOD_PATH, + DELAY, + **method_args2 + ) + + eventlet.sleep(WAIT) + + method.assert_any_call( + default_project_id, + default_project_id + ) + + method.assert_any_call( + second_project_id, + second_project_id + ) + + self.assertNotEqual(default_project_id, second_project_id) + @mock.patch(FACTORY_METHOD_PATH) def test_scheduler_with_serializer(self, factory): target_method = 'run_something'