Move API parameter parsing from engine to API

API parameter parsing should be done in the API rather than the engine.

Change-Id: I133f419bbb73964b461572d8cadd300233af9c12
This commit is contained in:
Jason Dunsmore 2014-04-17 16:17:22 -05:00
parent beb76daf44
commit 4169c1bd8c
6 changed files with 138 additions and 136 deletions

View File

@ -218,13 +218,14 @@ class StackController(object):
"""
data = InstantiationData(body)
args = util.extract_args(data.args())
result = self.rpc_client.preview_stack(req.context,
data.stack_name(),
data.template(),
data.environment(),
data.files(),
data.args())
args)
formatted_stack = stacks_view.format_stack(req, result)
return {'stack': formatted_stack}
@ -235,13 +236,14 @@ class StackController(object):
Create a new stack
"""
data = InstantiationData(body)
args = util.extract_args(data.args())
result = self.rpc_client.create_stack(req.context,
data.stack_name(),
data.template(),
data.environment(),
data.files(),
data.args())
args)
formatted_stack = stacks_view.format_stack(
req,
@ -303,13 +305,14 @@ class StackController(object):
Update an existing stack with a new template and/or parameters
"""
data = InstantiationData(body)
args = util.extract_args(data.args())
self.rpc_client.update_stack(req.context,
identity,
data.template(),
data.environment(),
data.files(),
data.args())
args)
raise exc.HTTPAccepted()

View File

@ -16,6 +16,12 @@ from functools import wraps
from webob import exc
from heat.common import identifier
from heat.common import template_format
from heat.openstack.common.gettextutils import _
from heat.openstack.common import log as logging
from heat.rpc import api
logger = logging.getLogger(__name__)
def policy_enforce(handler):
@ -99,3 +105,46 @@ def get_allowed_params(params, whitelist):
allowed_params[key] = value
return allowed_params
def extract_args(params):
'''
Extract any arguments passed as parameters through the API and return them
as a dictionary. This allows us to filter the passed args and do type
conversion where appropriate
'''
kwargs = {}
timeout_mins = params.get(api.PARAM_TIMEOUT)
if timeout_mins not in ('0', 0, None):
try:
timeout = int(timeout_mins)
except (ValueError, TypeError):
logger.exception(_('Timeout conversion failed'))
else:
if timeout > 0:
kwargs[api.PARAM_TIMEOUT] = timeout
else:
raise ValueError(_('Invalid timeout value %s') % timeout)
if api.PARAM_DISABLE_ROLLBACK in params:
disable_rollback = params.get(api.PARAM_DISABLE_ROLLBACK)
if str(disable_rollback).lower() == 'true':
kwargs[api.PARAM_DISABLE_ROLLBACK] = True
elif str(disable_rollback).lower() == 'false':
kwargs[api.PARAM_DISABLE_ROLLBACK] = False
else:
raise ValueError(_('Unexpected value for parameter'
' %(name)s : %(value)s') %
dict(name=api.PARAM_DISABLE_ROLLBACK,
value=disable_rollback))
adopt_data = params.get(api.PARAM_ADOPT_STACK_DATA)
if adopt_data:
adopt_data = template_format.simple_parse(adopt_data)
if not isinstance(adopt_data, dict):
raise ValueError(
_('Unexpected adopt data "%s". Adopt data must be a dict.')
% adopt_data)
kwargs[api.PARAM_ADOPT_STACK_DATA] = adopt_data
return kwargs

View File

@ -11,7 +11,6 @@
# License for the specific language governing permissions and limitations
# under the License.
from heat.common import template_format
from heat.engine import constraints as constr
from heat.openstack.common.gettextutils import _
from heat.openstack.common import log as logging
@ -21,49 +20,6 @@ from heat.rpc import api
logger = logging.getLogger(__name__)
def extract_args(params):
'''
Extract any arguments passed as parameters through the API and return them
as a dictionary. This allows us to filter the passed args and do type
conversion where appropriate
'''
kwargs = {}
timeout_mins = params.get(api.PARAM_TIMEOUT)
if timeout_mins not in ('0', 0, None):
try:
timeout = int(timeout_mins)
except (ValueError, TypeError):
logger.exception(_('Timeout conversion failed'))
else:
if timeout > 0:
kwargs[api.PARAM_TIMEOUT] = timeout
else:
raise ValueError(_('Invalid timeout value %s') % timeout)
if api.PARAM_DISABLE_ROLLBACK in params:
disable_rollback = params.get(api.PARAM_DISABLE_ROLLBACK)
if str(disable_rollback).lower() == 'true':
kwargs[api.PARAM_DISABLE_ROLLBACK] = True
elif str(disable_rollback).lower() == 'false':
kwargs[api.PARAM_DISABLE_ROLLBACK] = False
else:
raise ValueError(_('Unexpected value for parameter'
' %(name)s : %(value)s') %
dict(name=api.PARAM_DISABLE_ROLLBACK,
value=disable_rollback))
adopt_data = params.get(api.PARAM_ADOPT_STACK_DATA)
if adopt_data:
adopt_data = template_format.simple_parse(adopt_data)
if not isinstance(adopt_data, dict):
raise ValueError(
_('Unexpected adopt data "%s". Adopt data must be a dict.')
% adopt_data)
kwargs[api.PARAM_ADOPT_STACK_DATA] = adopt_data
return kwargs
def format_stack_outputs(stack, outputs):
'''
Return a representation of the given output template for the given stack

View File

@ -464,9 +464,8 @@ class EngineService(service.Service):
tmpl = parser.Template(template, files=files)
self._validate_new_stack(cnxt, stack_name, tmpl)
common_params = api.extract_args(args)
env = environment.Environment(params)
stack = parser.Stack(cnxt, stack_name, tmpl, env, **common_params)
stack = parser.Stack(cnxt, stack_name, tmpl, env, **args)
self._validate_deferred_auth_context(cnxt, stack)
stack.validate()
@ -508,11 +507,8 @@ class EngineService(service.Service):
tmpl = parser.Template(template, files=files)
self._validate_new_stack(cnxt, stack_name, tmpl)
# Extract the common query parameters
common_params = api.extract_args(args)
env = environment.Environment(params)
stack = parser.Stack(cnxt, stack_name, tmpl,
env, **common_params)
stack = parser.Stack(cnxt, stack_name, tmpl, env, **args)
self._validate_deferred_auth_context(cnxt, stack)
@ -563,12 +559,10 @@ class EngineService(service.Service):
raise exception.RequestLimitExceeded(
message=exception.StackResourceLimitExceeded.msg_fmt)
stack_name = current_stack.name
common_params = api.extract_args(args)
common_params.setdefault(rpc_api.PARAM_TIMEOUT,
current_stack.timeout_mins)
args.setdefault(rpc_api.PARAM_TIMEOUT, current_stack.timeout_mins)
env = environment.Environment(params)
updated_stack = parser.Stack(cnxt, stack_name, tmpl,
env, **common_params)
env, **args)
updated_stack.parameters.set_stack_id(current_stack.identifier())
self._validate_deferred_auth_context(cnxt, updated_stack)

View File

@ -11,9 +11,11 @@
# License for the specific language governing permissions and limitations
# under the License.
import json
import mock
from webob import exc
import six
from heat.api.openstack.v1 import util
from heat.common import context
@ -118,3 +120,80 @@ class TestPolicyEnforce(HeatTestCase):
self.assertRaises(exc.HTTPForbidden,
self.controller.an_action,
self.req, tenant_id='foo')
class TestExtractArgs(HeatTestCase):
def test_timeout_extract(self):
p = {'timeout_mins': '5'}
args = util.extract_args(p)
self.assertEqual(5, args['timeout_mins'])
def test_timeout_extract_zero(self):
p = {'timeout_mins': '0'}
args = util.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_garbage(self):
p = {'timeout_mins': 'wibble'}
args = util.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_none(self):
p = {'timeout_mins': None}
args = util.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_negative(self):
p = {'timeout_mins': '-100'}
error = self.assertRaises(ValueError, util.extract_args, p)
self.assertIn('Invalid timeout value', six.text_type(error))
def test_timeout_extract_not_present(self):
args = util.extract_args({})
self.assertNotIn('timeout_mins', args)
def test_adopt_stack_data_extract_present(self):
p = {'adopt_stack_data': json.dumps({'Resources': {}})}
args = util.extract_args(p)
self.assertTrue(args.get('adopt_stack_data'))
def test_invalid_adopt_stack_data(self):
p = {'adopt_stack_data': json.dumps("foo")}
error = self.assertRaises(ValueError, util.extract_args, p)
self.assertEqual(
'Unexpected adopt data "foo". Adopt data must be a dict.',
six.text_type(error))
def test_adopt_stack_data_extract_not_present(self):
args = util.extract_args({})
self.assertNotIn('adopt_stack_data', args)
def test_disable_rollback_extract_true(self):
args = util.extract_args({'disable_rollback': True})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
args = util.extract_args({'disable_rollback': 'True'})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
args = util.extract_args({'disable_rollback': 'true'})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
def test_disable_rollback_extract_false(self):
args = util.extract_args({'disable_rollback': False})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
args = util.extract_args({'disable_rollback': 'False'})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
args = util.extract_args({'disable_rollback': 'false'})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
def test_disable_rollback_extract_bad(self):
self.assertRaises(ValueError, util.extract_args,
{'disable_rollback': 'bad'})

View File

@ -12,11 +12,9 @@
# under the License.
from datetime import datetime
import json
import uuid
import mock
import six
from heat.common.identifier import EventIdentifier
from heat.common import template_format
@ -31,83 +29,6 @@ from heat.tests import generic_resource as generic_rsrc
from heat.tests import utils
class EngineApiTest(HeatTestCase):
def test_timeout_extract(self):
p = {'timeout_mins': '5'}
args = api.extract_args(p)
self.assertEqual(5, args['timeout_mins'])
def test_timeout_extract_zero(self):
p = {'timeout_mins': '0'}
args = api.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_garbage(self):
p = {'timeout_mins': 'wibble'}
args = api.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_none(self):
p = {'timeout_mins': None}
args = api.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_negative(self):
p = {'timeout_mins': '-100'}
error = self.assertRaises(ValueError, api.extract_args, p)
self.assertIn('Invalid timeout value', six.text_type(error))
def test_timeout_extract_not_present(self):
args = api.extract_args({})
self.assertNotIn('timeout_mins', args)
def test_adopt_stack_data_extract_present(self):
p = {'adopt_stack_data': json.dumps({'Resources': {}})}
args = api.extract_args(p)
self.assertTrue(args.get('adopt_stack_data'))
def test_invalid_adopt_stack_data(self):
p = {'adopt_stack_data': json.dumps("foo")}
error = self.assertRaises(ValueError, api.extract_args, p)
self.assertEqual(
'Unexpected adopt data "foo". Adopt data must be a dict.',
six.text_type(error))
def test_adopt_stack_data_extract_not_present(self):
args = api.extract_args({})
self.assertNotIn('adopt_stack_data', args)
def test_disable_rollback_extract_true(self):
args = api.extract_args({'disable_rollback': True})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
args = api.extract_args({'disable_rollback': 'True'})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
args = api.extract_args({'disable_rollback': 'true'})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
def test_disable_rollback_extract_false(self):
args = api.extract_args({'disable_rollback': False})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
args = api.extract_args({'disable_rollback': 'False'})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
args = api.extract_args({'disable_rollback': 'false'})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
def test_disable_rollback_extract_bad(self):
self.assertRaises(ValueError, api.extract_args,
{'disable_rollback': 'bad'})
class FormatTest(HeatTestCase):
def setUp(self):
super(FormatTest, self).setUp()