From 082cf29cec31efda2d63eeb9dda8c74e396be6eb Mon Sep 17 00:00:00 2001
From: Jim Rollenhagen <jim@jimrollenhagen.com>
Date: Wed, 1 Oct 2014 10:59:09 -0700
Subject: [PATCH] Force heartbeat immediately after async command completes

This change passes the agent object into extensions, such that
they may call agent methods as needed. It also causes async
commands to force a heartbeat immediately after completing the
command. This allows Ironic to get a heartbeat and continue
work as quickly as possible, while also allowing deployers to
configure Ironic (agent) to heartbeat less often.

Change-Id: Ib3c3a43dfd0ed4ed51b7d52ac099f01181ca822f
---
 ironic_python_agent/agent.py                 |  7 ++++++
 ironic_python_agent/extensions/base.py       | 14 ++++++++---
 ironic_python_agent/extensions/standby.py    |  4 +--
 ironic_python_agent/tests/agent.py           |  3 +--
 ironic_python_agent/tests/extensions/base.py | 26 +++++++++++++++++++-
 5 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/ironic_python_agent/agent.py b/ironic_python_agent/agent.py
index 079a323c9..5bf2eeca9 100644
--- a/ironic_python_agent/agent.py
+++ b/ironic_python_agent/agent.py
@@ -121,6 +121,9 @@ class IronicPythonAgentHeartbeater(threading.Thread):
             self.error_delay = min(self.error_delay * self.backoff_factor,
                                    self.max_delay)
 
+    def force_heartbeat(self):
+        os.write(self.writer, 'b')
+
     def stop(self):
         """Stop the heartbeat thread."""
         if self.writer is not None:
@@ -140,6 +143,7 @@ class IronicPythonAgent(base.ExecuteCommandMixin):
             namespace='ironic_python_agent.extensions',
             invoke_on_load=True,
             propagate_map_exceptions=True,
+            invoke_kwds={'agent': self},
         )
         self.api_url = api_url
         self.driver_name = driver_name
@@ -262,6 +266,9 @@ class IronicPythonAgent(base.ExecuteCommandMixin):
             raise errors.RequestedObjectNotFoundError('Command Result',
                                                       result_id)
 
+    def force_heartbeat(self):
+        self.heartbeater.force_heartbeat()
+
     def run(self):
         """Run the Ironic Python Agent."""
         # Get the UUID so we can heartbeat to Ironic. Raises LookupNodeError
diff --git a/ironic_python_agent/extensions/base.py b/ironic_python_agent/extensions/base.py
index 9266cbbb3..14021906f 100644
--- a/ironic_python_agent/extensions/base.py
+++ b/ironic_python_agent/extensions/base.py
@@ -67,8 +67,10 @@ class AsyncCommandResult(BaseCommandResult):
 
     :param execute_method: a callable to be executed asynchronously
     """
-    def __init__(self, command_name, command_params, execute_method):
+    def __init__(self, command_name, command_params, execute_method,
+                 agent=None):
         super(AsyncCommandResult, self).__init__(command_name, command_params)
+        self.agent = agent
         self.execute_method = execute_method
         self.command_state_lock = threading.Lock()
 
@@ -98,7 +100,6 @@ class AsyncCommandResult(BaseCommandResult):
             with self.command_state_lock:
                 self.command_result = result
                 self.command_status = AgentCommandStatus.SUCCEEDED
-
         except Exception as e:
             if not isinstance(e, errors.RESTError):
                 e = errors.CommandExecutionError(str(e))
@@ -106,11 +107,15 @@ class AsyncCommandResult(BaseCommandResult):
             with self.command_state_lock:
                 self.command_error = e
                 self.command_status = AgentCommandStatus.FAILED
+        finally:
+            if self.agent:
+                self.agent.force_heartbeat()
 
 
 class BaseAgentExtension(object):
-    def __init__(self):
+    def __init__(self, agent=None):
         super(BaseAgentExtension, self).__init__()
+        self.agent = agent
         self.log = log.getLogger(__name__)
         self.command_map = dict(
             (v.command_name, v)
@@ -207,7 +212,8 @@ def async_command(command_name, validator=None):
 
             return AsyncCommandResult(command_name,
                                       command_params,
-                                      bound_func).start()
+                                      bound_func,
+                                      agent=self.agent).start()
         return wrapper
     return async_decorator
 
diff --git a/ironic_python_agent/extensions/standby.py b/ironic_python_agent/extensions/standby.py
index 6d08672bb..eb8faa8c2 100644
--- a/ironic_python_agent/extensions/standby.py
+++ b/ironic_python_agent/extensions/standby.py
@@ -172,8 +172,8 @@ def _validate_image_info(ext, image_info=None, **kwargs):
 
 
 class StandbyExtension(base.BaseAgentExtension):
-    def __init__(self):
-        super(StandbyExtension, self).__init__()
+    def __init__(self, agent=None):
+        super(StandbyExtension, self).__init__(agent=agent)
 
         self.cached_image_id = None
 
diff --git a/ironic_python_agent/tests/agent.py b/ironic_python_agent/tests/agent.py
index b2f46d996..747cfe396 100644
--- a/ironic_python_agent/tests/agent.py
+++ b/ironic_python_agent/tests/agent.py
@@ -49,8 +49,7 @@ def foo_execute(*args, **kwargs):
 
 
 class FakeExtension(base.BaseAgentExtension):
-    def __init__(self):
-        super(FakeExtension, self).__init__()
+    pass
 
 
 class TestHeartbeater(test_base.BaseTestCase):
diff --git a/ironic_python_agent/tests/extensions/base.py b/ironic_python_agent/tests/extensions/base.py
index 4be19d94f..30ce34a07 100644
--- a/ironic_python_agent/tests/extensions/base.py
+++ b/ironic_python_agent/tests/extensions/base.py
@@ -120,7 +120,9 @@ class TestExecuteCommandMixin(test_base.BaseTestCase):
 class TestExtensionDecorators(test_base.BaseTestCase):
     def setUp(self):
         super(TestExtensionDecorators, self).setUp()
-        self.extension = FakeExtension()
+        self.agent = FakeAgent()
+        self.agent.force_heartbeat = mock.Mock()
+        self.extension = FakeExtension(agent=self.agent)
 
     def test_async_command_success(self):
         result = self.extension.execute('fake_async_command', param='v1')
@@ -132,12 +134,27 @@ class TestExtensionDecorators(test_base.BaseTestCase):
                          result.command_status)
         self.assertEqual(None, result.command_error)
         self.assertEqual('v1', result.command_result)
+        self.agent.force_heartbeat.assert_called_once_with()
+
+    def test_async_command_success_without_agent(self):
+        extension = FakeExtension(agent=None)
+        result = extension.execute('fake_async_command', param='v1')
+        self.assertIsInstance(result, base.AsyncCommandResult)
+        result.join()
+        self.assertEqual('fake_async_command', result.command_name)
+        self.assertEqual({'param': 'v1'}, result.command_params)
+        self.assertEqual(base.AgentCommandStatus.SUCCEEDED,
+                         result.command_status)
+        self.assertEqual(None, result.command_error)
+        self.assertEqual('v1', result.command_result)
 
     def test_async_command_validation_failure(self):
         self.assertRaises(errors.InvalidCommandParamsError,
                           self.extension.execute,
                           'fake_async_command',
                           is_valid=False)
+        # validation is synchronous, no need to force a heartbeat
+        self.assertEqual(0, self.agent.force_heartbeat.call_count)
 
     def test_async_command_execution_failure(self):
         result = self.extension.execute('fake_async_command', param='v2')
@@ -149,6 +166,7 @@ class TestExtensionDecorators(test_base.BaseTestCase):
                          result.command_status)
         self.assertIsInstance(result.command_error, ExecutionError)
         self.assertEqual(None, result.command_result)
+        self.agent.force_heartbeat.assert_called_once_with()
 
     def test_async_command_name(self):
         self.assertEqual(
@@ -164,18 +182,24 @@ class TestExtensionDecorators(test_base.BaseTestCase):
                          result.command_status)
         self.assertEqual(None, result.command_error)
         self.assertEqual('v1', result.command_result)
+        # no need to force heartbeat on a sync command
+        self.assertEqual(0, self.agent.force_heartbeat.call_count)
 
     def test_sync_command_validation_failure(self):
         self.assertRaises(errors.InvalidCommandParamsError,
                           self.extension.execute,
                           'fake_sync_command',
                           is_valid=False)
+        # validation is synchronous, no need to force a heartbeat
+        self.assertEqual(0, self.agent.force_heartbeat.call_count)
 
     def test_sync_command_execution_failure(self):
         self.assertRaises(ExecutionError,
                           self.extension.execute,
                           'fake_sync_command',
                           param='v2')
+        # no need to force heartbeat on a sync command
+        self.assertEqual(0, self.agent.force_heartbeat.call_count)
 
     def test_sync_command_name(self):
         self.assertEqual(