diff --git a/ironic_python_agent/agent.py b/ironic_python_agent/agent.py index 079a323c9..5bf2eeca9 100644 --- a/ironic_python_agent/agent.py +++ b/ironic_python_agent/agent.py @@ -121,6 +121,9 @@ class IronicPythonAgentHeartbeater(threading.Thread): self.error_delay = min(self.error_delay * self.backoff_factor, self.max_delay) + def force_heartbeat(self): + os.write(self.writer, 'b') + def stop(self): """Stop the heartbeat thread.""" if self.writer is not None: @@ -140,6 +143,7 @@ class IronicPythonAgent(base.ExecuteCommandMixin): namespace='ironic_python_agent.extensions', invoke_on_load=True, propagate_map_exceptions=True, + invoke_kwds={'agent': self}, ) self.api_url = api_url self.driver_name = driver_name @@ -262,6 +266,9 @@ class IronicPythonAgent(base.ExecuteCommandMixin): raise errors.RequestedObjectNotFoundError('Command Result', result_id) + def force_heartbeat(self): + self.heartbeater.force_heartbeat() + def run(self): """Run the Ironic Python Agent.""" # Get the UUID so we can heartbeat to Ironic. Raises LookupNodeError diff --git a/ironic_python_agent/extensions/base.py b/ironic_python_agent/extensions/base.py index 9266cbbb3..14021906f 100644 --- a/ironic_python_agent/extensions/base.py +++ b/ironic_python_agent/extensions/base.py @@ -67,8 +67,10 @@ class AsyncCommandResult(BaseCommandResult): :param execute_method: a callable to be executed asynchronously """ - def __init__(self, command_name, command_params, execute_method): + def __init__(self, command_name, command_params, execute_method, + agent=None): super(AsyncCommandResult, self).__init__(command_name, command_params) + self.agent = agent self.execute_method = execute_method self.command_state_lock = threading.Lock() @@ -98,7 +100,6 @@ class AsyncCommandResult(BaseCommandResult): with self.command_state_lock: self.command_result = result self.command_status = AgentCommandStatus.SUCCEEDED - except Exception as e: if not isinstance(e, errors.RESTError): e = errors.CommandExecutionError(str(e)) @@ -106,11 +107,15 @@ class AsyncCommandResult(BaseCommandResult): with self.command_state_lock: self.command_error = e self.command_status = AgentCommandStatus.FAILED + finally: + if self.agent: + self.agent.force_heartbeat() class BaseAgentExtension(object): - def __init__(self): + def __init__(self, agent=None): super(BaseAgentExtension, self).__init__() + self.agent = agent self.log = log.getLogger(__name__) self.command_map = dict( (v.command_name, v) @@ -207,7 +212,8 @@ def async_command(command_name, validator=None): return AsyncCommandResult(command_name, command_params, - bound_func).start() + bound_func, + agent=self.agent).start() return wrapper return async_decorator diff --git a/ironic_python_agent/extensions/standby.py b/ironic_python_agent/extensions/standby.py index 6d08672bb..eb8faa8c2 100644 --- a/ironic_python_agent/extensions/standby.py +++ b/ironic_python_agent/extensions/standby.py @@ -172,8 +172,8 @@ def _validate_image_info(ext, image_info=None, **kwargs): class StandbyExtension(base.BaseAgentExtension): - def __init__(self): - super(StandbyExtension, self).__init__() + def __init__(self, agent=None): + super(StandbyExtension, self).__init__(agent=agent) self.cached_image_id = None diff --git a/ironic_python_agent/tests/agent.py b/ironic_python_agent/tests/agent.py index b2f46d996..747cfe396 100644 --- a/ironic_python_agent/tests/agent.py +++ b/ironic_python_agent/tests/agent.py @@ -49,8 +49,7 @@ def foo_execute(*args, **kwargs): class FakeExtension(base.BaseAgentExtension): - def __init__(self): - super(FakeExtension, self).__init__() + pass class TestHeartbeater(test_base.BaseTestCase): diff --git a/ironic_python_agent/tests/extensions/base.py b/ironic_python_agent/tests/extensions/base.py index 4be19d94f..30ce34a07 100644 --- a/ironic_python_agent/tests/extensions/base.py +++ b/ironic_python_agent/tests/extensions/base.py @@ -120,7 +120,9 @@ class TestExecuteCommandMixin(test_base.BaseTestCase): class TestExtensionDecorators(test_base.BaseTestCase): def setUp(self): super(TestExtensionDecorators, self).setUp() - self.extension = FakeExtension() + self.agent = FakeAgent() + self.agent.force_heartbeat = mock.Mock() + self.extension = FakeExtension(agent=self.agent) def test_async_command_success(self): result = self.extension.execute('fake_async_command', param='v1') @@ -132,12 +134,27 @@ class TestExtensionDecorators(test_base.BaseTestCase): result.command_status) self.assertEqual(None, result.command_error) self.assertEqual('v1', result.command_result) + self.agent.force_heartbeat.assert_called_once_with() + + def test_async_command_success_without_agent(self): + extension = FakeExtension(agent=None) + result = extension.execute('fake_async_command', param='v1') + self.assertIsInstance(result, base.AsyncCommandResult) + result.join() + self.assertEqual('fake_async_command', result.command_name) + self.assertEqual({'param': 'v1'}, result.command_params) + self.assertEqual(base.AgentCommandStatus.SUCCEEDED, + result.command_status) + self.assertEqual(None, result.command_error) + self.assertEqual('v1', result.command_result) def test_async_command_validation_failure(self): self.assertRaises(errors.InvalidCommandParamsError, self.extension.execute, 'fake_async_command', is_valid=False) + # validation is synchronous, no need to force a heartbeat + self.assertEqual(0, self.agent.force_heartbeat.call_count) def test_async_command_execution_failure(self): result = self.extension.execute('fake_async_command', param='v2') @@ -149,6 +166,7 @@ class TestExtensionDecorators(test_base.BaseTestCase): result.command_status) self.assertIsInstance(result.command_error, ExecutionError) self.assertEqual(None, result.command_result) + self.agent.force_heartbeat.assert_called_once_with() def test_async_command_name(self): self.assertEqual( @@ -164,18 +182,24 @@ class TestExtensionDecorators(test_base.BaseTestCase): result.command_status) self.assertEqual(None, result.command_error) self.assertEqual('v1', result.command_result) + # no need to force heartbeat on a sync command + self.assertEqual(0, self.agent.force_heartbeat.call_count) def test_sync_command_validation_failure(self): self.assertRaises(errors.InvalidCommandParamsError, self.extension.execute, 'fake_sync_command', is_valid=False) + # validation is synchronous, no need to force a heartbeat + self.assertEqual(0, self.agent.force_heartbeat.call_count) def test_sync_command_execution_failure(self): self.assertRaises(ExecutionError, self.extension.execute, 'fake_sync_command', param='v2') + # no need to force heartbeat on a sync command + self.assertEqual(0, self.agent.force_heartbeat.call_count) def test_sync_command_name(self): self.assertEqual(