From dcd6e3e0dcadda84dcbf478c71460e4287e01c7b Mon Sep 17 00:00:00 2001
From: Russell Haering <russellhaering@gmail.com>
Date: Sun, 13 Apr 2014 17:37:57 -0700
Subject: [PATCH] Add a @sync_command() decorator

Instead of special-casing syncronous commands, add a decorator similar
to @async_command(), which can be used to decorate extension methods
for execution as a synchronous command.

Change-Id: I1b27f179f667cb065bcffd71ae0f303b05d33b82
---
 ironic_python_agent/extensions/base.py       | 34 ++++++---
 ironic_python_agent/tests/extensions/base.py | 80 ++++++++++++++++++++
 ironic_python_agent/tests/extensions/flow.py |  1 +
 3 files changed, 106 insertions(+), 9 deletions(-)

diff --git a/ironic_python_agent/extensions/base.py b/ironic_python_agent/extensions/base.py
index 54fde2c6e..14a735503 100644
--- a/ironic_python_agent/extensions/base.py
+++ b/ironic_python_agent/extensions/base.py
@@ -127,15 +127,7 @@ class BaseAgentExtension(object):
             raise errors.InvalidCommandError(
                 'Unknown command: {0}'.format(command_name))
 
-        result = self.command_map[command_name](command_name, **kwargs)
-
-        # In order to enable extremely succinct synchronous commands, we allow
-        # them to return a value directly, and we'll handle wrapping it up in a
-        # SyncCommandResult
-        if not isinstance(result, BaseCommandResult):
-            result = SyncCommandResult(command_name, kwargs, True, result)
-
-        return result
+        return self.command_map[command_name](command_name, **kwargs)
 
     def check_cmd_presence(self, ext_obj, ext, cmd):
         if not (hasattr(ext_obj, 'execute') and hasattr(ext_obj, 'command_map')
@@ -217,3 +209,27 @@ def async_command(validator=None):
                                       bound_func).start()
         return wrapper
     return async_decorator
+
+
+def sync_command(validator=None):
+    """Decorate a method in order to wrap up its return value in a
+    SyncCommandResult. For consistency with @async_command() can also accept a
+    validator which will be used to validate input, although a synchronous
+    command can also choose to implement validation inline.
+    """
+    def sync_decorator(func):
+        @functools.wraps(func)
+        def wrapper(self, command_name, **command_params):
+            # Run a validator before passing everything off to async.
+            # validators should raise exceptions or return silently.
+            if validator:
+                validator(self, **command_params)
+
+            result = func(self, command_name, **command_params)
+            return SyncCommandResult(command_name,
+                                     command_params,
+                                     True,
+                                     result)
+
+        return wrapper
+    return sync_decorator
diff --git a/ironic_python_agent/tests/extensions/base.py b/ironic_python_agent/tests/extensions/base.py
index 8eb915540..24530a199 100644
--- a/ironic_python_agent/tests/extensions/base.py
+++ b/ironic_python_agent/tests/extensions/base.py
@@ -20,9 +20,33 @@ from ironic_python_agent import errors
 from ironic_python_agent.extensions import base
 
 
+def _fake_validator(ext, **kwargs):
+    if not kwargs.get('is_valid', True):
+        raise errors.InvalidCommandParamsError('error')
+
+
+class ExecutionError(errors.RESTError):
+    def __init__(self):
+        super(ExecutionError, self).__init__('failed')
+
+
 class FakeExtension(base.BaseAgentExtension):
     def __init__(self):
         super(FakeExtension, self).__init__('FAKE')
+        self.command_map['fake_async_command'] = self.fake_async_command
+        self.command_map['fake_sync_command'] = self.fake_sync_command
+
+    @base.async_command(_fake_validator)
+    def fake_async_command(self, command_name, is_valid=False, param=None):
+        if param == 'v2':
+            raise ExecutionError()
+        return param
+
+    @base.sync_command(_fake_validator)
+    def fake_sync_command(self, command_name, is_valid=False, param=None):
+        if param == 'v2':
+            raise ExecutionError()
+        return param
 
 
 class FakeAgent(base.ExecuteCommandMixin):
@@ -90,3 +114,59 @@ class TestExecuteCommandMixin(test_base.BaseTestCase):
         self.assertEqual(result.command_status,
                          base.AgentCommandStatus.FAILED)
         self.assertEqual(result.command_error, msg)
+
+
+class TestExtensionDecorators(test_base.BaseTestCase):
+    def setUp(self):
+        super(TestExtensionDecorators, self).setUp()
+        self.extension = FakeExtension()
+
+    def test_async_command_success(self):
+        result = self.extension.execute('fake_async_command', param='v1')
+        self.assertIsInstance(result, base.AsyncCommandResult)
+        result.join()
+        self.assertEqual('fake_async_command', result.command_name)
+        self.assertEqual({'param': 'v1'}, result.command_params)
+        self.assertEqual(base.AgentCommandStatus.SUCCEEDED,
+                         result.command_status)
+        self.assertEqual(None, result.command_error)
+        self.assertEqual('v1', result.command_result)
+
+    def test_async_command_validation_failure(self):
+        self.assertRaises(errors.InvalidCommandParamsError,
+                          self.extension.execute,
+                          'fake_async_command',
+                          is_valid=False)
+
+    def test_async_command_execution_failure(self):
+        result = self.extension.execute('fake_async_command', param='v2')
+        self.assertIsInstance(result, base.AsyncCommandResult)
+        result.join()
+        self.assertEqual('fake_async_command', result.command_name)
+        self.assertEqual({'param': 'v2'}, result.command_params)
+        self.assertEqual(base.AgentCommandStatus.FAILED,
+                         result.command_status)
+        self.assertIsInstance(result.command_error, ExecutionError)
+        self.assertEqual(None, result.command_result)
+
+    def test_sync_command_success(self):
+        result = self.extension.execute('fake_sync_command', param='v1')
+        self.assertIsInstance(result, base.SyncCommandResult)
+        self.assertEqual('fake_sync_command', result.command_name)
+        self.assertEqual({'param': 'v1'}, result.command_params)
+        self.assertEqual(base.AgentCommandStatus.SUCCEEDED,
+                         result.command_status)
+        self.assertEqual(None, result.command_error)
+        self.assertEqual('v1', result.command_result)
+
+    def test_sync_command_validation_failure(self):
+        self.assertRaises(errors.InvalidCommandParamsError,
+                          self.extension.execute,
+                          'fake_sync_command',
+                          is_valid=False)
+
+    def test_sync_command_execution_failure(self):
+        self.assertRaises(ExecutionError,
+                          self.extension.execute,
+                          'fake_sync_command',
+                          param='v2')
diff --git a/ironic_python_agent/tests/extensions/flow.py b/ironic_python_agent/tests/extensions/flow.py
index d8b26ccba..7588a4d2c 100644
--- a/ironic_python_agent/tests/extensions/flow.py
+++ b/ironic_python_agent/tests/extensions/flow.py
@@ -45,6 +45,7 @@ class FakeExtension(base.BaseAgentExtension):
     def sleep(self, command_name, sleep_info=None):
         time.sleep(sleep_info['time'])
 
+    @base.sync_command()
     def sync_sleep(self, command_name, sleep_info=None):
         time.sleep(sleep_info['time'])