Move API parameter parsing from OpenStack API to engine

In commit 4169c1bd8c, the API parameter
parsing was moved to the OpenStack API.  Since then, inputs to the CFN
API were not being validated, creating a security hole.

Change-Id: I21920591075bcefbe695316dab6605afd6f4ec64
Closes-Bug: #1317667
This commit is contained in:
Jason Dunsmore 2014-05-09 09:31:12 -05:00 committed by Anderson Mesquita
parent 42acc9d175
commit 5d1b4ab3e9
8 changed files with 186 additions and 153 deletions

View File

@ -21,6 +21,7 @@ from heat.api.openstack.v1 import util
from heat.api.openstack.v1.views import stacks_view from heat.api.openstack.v1.views import stacks_view
from heat.common import environment_format from heat.common import environment_format
from heat.common import identifier from heat.common import identifier
from heat.common import param_utils
from heat.common import serializers from heat.common import serializers
from heat.common import template_format from heat.common import template_format
from heat.common import urlfetch from heat.common import urlfetch
@ -167,8 +168,8 @@ class StackController(object):
filter_params = util.get_allowed_params(req.params, filter_whitelist) filter_params = util.get_allowed_params(req.params, filter_whitelist)
if engine_api.PARAM_SHOW_DELETED in params: if engine_api.PARAM_SHOW_DELETED in params:
show_del = util.extract_bool(params[engine_api.PARAM_SHOW_DELETED]) params[engine_api.PARAM_SHOW_DELETED] = param_utils.extract_bool(
params[engine_api.PARAM_SHOW_DELETED] = show_del params[engine_api.PARAM_SHOW_DELETED])
if not filter_params: if not filter_params:
filter_params = None filter_params = None
@ -223,14 +224,13 @@ class StackController(object):
""" """
data = InstantiationData(body) data = InstantiationData(body)
args = util.extract_args(data.args())
result = self.rpc_client.preview_stack(req.context, result = self.rpc_client.preview_stack(req.context,
data.stack_name(), data.stack_name(),
data.template(), data.template(),
data.environment(), data.environment(),
data.files(), data.files(),
args) data.args())
formatted_stack = stacks_view.format_stack(req, result) formatted_stack = stacks_view.format_stack(req, result)
return {'stack': formatted_stack} return {'stack': formatted_stack}
@ -241,14 +241,13 @@ class StackController(object):
Create a new stack Create a new stack
""" """
data = InstantiationData(body) data = InstantiationData(body)
args = util.extract_args(data.args())
result = self.rpc_client.create_stack(req.context, result = self.rpc_client.create_stack(req.context,
data.stack_name(), data.stack_name(),
data.template(), data.template(),
data.environment(), data.environment(),
data.files(), data.files(),
args) data.args())
formatted_stack = stacks_view.format_stack( formatted_stack = stacks_view.format_stack(
req, req,
@ -310,14 +309,13 @@ class StackController(object):
Update an existing stack with a new template and/or parameters Update an existing stack with a new template and/or parameters
""" """
data = InstantiationData(body) data = InstantiationData(body)
args = util.extract_args(data.args())
self.rpc_client.update_stack(req.context, self.rpc_client.update_stack(req.context,
identity, identity,
data.template(), data.template(),
data.environment(), data.environment(),
data.files(), data.files(),
args) data.args())
raise exc.HTTPAccepted() raise exc.HTTPAccepted()

View File

@ -16,13 +16,7 @@ from functools import wraps
from webob import exc from webob import exc
from heat.common import identifier from heat.common import identifier
from heat.common import template_format
from heat.openstack.common.gettextutils import _ from heat.openstack.common.gettextutils import _
from heat.openstack.common import log as logging
from heat.openstack.common import strutils
from heat.rpc import api
logger = logging.getLogger(__name__)
def policy_enforce(handler): def policy_enforce(handler):
@ -106,49 +100,3 @@ def get_allowed_params(params, whitelist):
allowed_params[key] = value allowed_params[key] = value
return allowed_params return allowed_params
def extract_args(params):
'''
Extract any arguments passed as parameters through the API and return them
as a dictionary. This allows us to filter the passed args and do type
conversion where appropriate
'''
kwargs = {}
timeout_mins = params.get(api.PARAM_TIMEOUT)
if timeout_mins not in ('0', 0, None):
try:
timeout = int(timeout_mins)
except (ValueError, TypeError):
logger.exception(_('Timeout conversion failed'))
else:
if timeout > 0:
kwargs[api.PARAM_TIMEOUT] = timeout
else:
raise ValueError(_('Invalid timeout value %s') % timeout)
if api.PARAM_DISABLE_ROLLBACK in params:
disable_rollback = extract_bool(params[api.PARAM_DISABLE_ROLLBACK])
kwargs[api.PARAM_DISABLE_ROLLBACK] = disable_rollback
adopt_data = params.get(api.PARAM_ADOPT_STACK_DATA)
if adopt_data:
adopt_data = template_format.simple_parse(adopt_data)
if not isinstance(adopt_data, dict):
raise ValueError(
_('Unexpected adopt data "%s". Adopt data must be a dict.')
% adopt_data)
kwargs[api.PARAM_ADOPT_STACK_DATA] = adopt_data
return kwargs
def extract_bool(subject):
'''
Convert any true/false string to its corresponding boolean value,
regardless of case.
'''
if str(subject).lower() not in ('true', 'false'):
raise ValueError(_('Unrecognized value "%(value)s, acceptable values '
'are: true, false.') % {'value': subject})
return strutils.bool_from_string(subject, strict=True)

View File

@ -0,0 +1,26 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from heat.openstack.common.gettextutils import _
from heat.openstack.common import strutils
def extract_bool(subject):
'''
Convert any true/false string to its corresponding boolean value,
regardless of case.
'''
if str(subject).lower() not in ('true', 'false'):
raise ValueError(_('Unrecognized value "%(value)s, acceptable values '
'are: true, false.') % {'value': subject})
return strutils.bool_from_string(subject, strict=True)

View File

@ -11,6 +11,8 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from heat.common import param_utils
from heat.common import template_format
from heat.engine import constraints as constr from heat.engine import constraints as constr
from heat.openstack.common.gettextutils import _ from heat.openstack.common.gettextutils import _
from heat.openstack.common import log as logging from heat.openstack.common import log as logging
@ -20,6 +22,46 @@ from heat.rpc import api
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def extract_args(params):
'''
Extract any arguments passed as parameters through the API and return them
as a dictionary. This allows us to filter the passed args and do type
conversion where appropriate
'''
kwargs = {}
timeout_mins = params.get(api.PARAM_TIMEOUT)
if timeout_mins not in ('0', 0, None):
try:
timeout = int(timeout_mins)
except (ValueError, TypeError):
logger.exception(_('Timeout conversion failed'))
else:
if timeout > 0:
kwargs[api.PARAM_TIMEOUT] = timeout
else:
raise ValueError(_('Invalid timeout value %s') % timeout)
if api.PARAM_DISABLE_ROLLBACK in params:
disable_rollback = param_utils.extract_bool(
params[api.PARAM_DISABLE_ROLLBACK])
kwargs[api.PARAM_DISABLE_ROLLBACK] = disable_rollback
if api.PARAM_SHOW_DELETED in params:
params[api.PARAM_SHOW_DELETED] = param_utils.extract_bool(
params[api.PARAM_SHOW_DELETED])
adopt_data = params.get(api.PARAM_ADOPT_STACK_DATA)
if adopt_data:
adopt_data = template_format.simple_parse(adopt_data)
if not isinstance(adopt_data, dict):
raise ValueError(
_('Unexpected adopt data "%s". Adopt data must be a dict.')
% adopt_data)
kwargs[api.PARAM_ADOPT_STACK_DATA] = adopt_data
return kwargs
def format_stack_outputs(stack, outputs): def format_stack_outputs(stack, outputs):
''' '''
Return a representation of the given output template for the given stack Return a representation of the given output template for the given stack

View File

@ -495,8 +495,9 @@ class EngineService(service.Service):
tmpl = parser.Template(template, files=files) tmpl = parser.Template(template, files=files)
self._validate_new_stack(cnxt, stack_name, tmpl) self._validate_new_stack(cnxt, stack_name, tmpl)
common_params = api.extract_args(args)
env = environment.Environment(params) env = environment.Environment(params)
stack = parser.Stack(cnxt, stack_name, tmpl, env, **args) stack = parser.Stack(cnxt, stack_name, tmpl, env, **common_params)
self._validate_deferred_auth_context(cnxt, stack) self._validate_deferred_auth_context(cnxt, stack)
stack.validate() stack.validate()
@ -538,8 +539,9 @@ class EngineService(service.Service):
tmpl = parser.Template(template, files=files) tmpl = parser.Template(template, files=files)
self._validate_new_stack(cnxt, stack_name, tmpl) self._validate_new_stack(cnxt, stack_name, tmpl)
common_params = api.extract_args(args)
env = environment.Environment(params) env = environment.Environment(params)
stack = parser.Stack(cnxt, stack_name, tmpl, env, **args) stack = parser.Stack(cnxt, stack_name, tmpl, env, **common_params)
self._validate_deferred_auth_context(cnxt, stack) self._validate_deferred_auth_context(cnxt, stack)
@ -590,10 +592,12 @@ class EngineService(service.Service):
raise exception.RequestLimitExceeded( raise exception.RequestLimitExceeded(
message=exception.StackResourceLimitExceeded.msg_fmt) message=exception.StackResourceLimitExceeded.msg_fmt)
stack_name = current_stack.name stack_name = current_stack.name
args.setdefault(rpc_api.PARAM_TIMEOUT, current_stack.timeout_mins) common_params = api.extract_args(args)
common_params.setdefault(rpc_api.PARAM_TIMEOUT,
current_stack.timeout_mins)
env = environment.Environment(params) env = environment.Environment(params)
updated_stack = parser.Stack(cnxt, stack_name, tmpl, updated_stack = parser.Stack(cnxt, stack_name, tmpl,
env, **args) env, **common_params)
updated_stack.parameters.set_stack_id(current_stack.identifier()) updated_stack.parameters.set_stack_id(current_stack.identifier())
self._validate_deferred_auth_context(cnxt, updated_stack) self._validate_deferred_auth_context(cnxt, updated_stack)

View File

@ -11,11 +11,9 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import json
import mock import mock
from webob import exc from webob import exc
import six
from heat.api.openstack.v1 import util from heat.api.openstack.v1 import util
from heat.common import context from heat.common import context
@ -120,90 +118,3 @@ class TestPolicyEnforce(HeatTestCase):
self.assertRaises(exc.HTTPForbidden, self.assertRaises(exc.HTTPForbidden,
self.controller.an_action, self.controller.an_action,
self.req, tenant_id='foo') self.req, tenant_id='foo')
class TestExtractArgs(HeatTestCase):
def test_timeout_extract(self):
p = {'timeout_mins': '5'}
args = util.extract_args(p)
self.assertEqual(5, args['timeout_mins'])
def test_timeout_extract_zero(self):
p = {'timeout_mins': '0'}
args = util.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_garbage(self):
p = {'timeout_mins': 'wibble'}
args = util.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_none(self):
p = {'timeout_mins': None}
args = util.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_negative(self):
p = {'timeout_mins': '-100'}
error = self.assertRaises(ValueError, util.extract_args, p)
self.assertIn('Invalid timeout value', six.text_type(error))
def test_timeout_extract_not_present(self):
args = util.extract_args({})
self.assertNotIn('timeout_mins', args)
def test_adopt_stack_data_extract_present(self):
p = {'adopt_stack_data': json.dumps({'Resources': {}})}
args = util.extract_args(p)
self.assertTrue(args.get('adopt_stack_data'))
def test_invalid_adopt_stack_data(self):
p = {'adopt_stack_data': json.dumps("foo")}
error = self.assertRaises(ValueError, util.extract_args, p)
self.assertEqual(
'Unexpected adopt data "foo". Adopt data must be a dict.',
six.text_type(error))
def test_adopt_stack_data_extract_not_present(self):
args = util.extract_args({})
self.assertNotIn('adopt_stack_data', args)
def test_disable_rollback_extract_true(self):
args = util.extract_args({'disable_rollback': True})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
args = util.extract_args({'disable_rollback': 'True'})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
args = util.extract_args({'disable_rollback': 'true'})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
def test_disable_rollback_extract_false(self):
args = util.extract_args({'disable_rollback': False})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
args = util.extract_args({'disable_rollback': 'False'})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
args = util.extract_args({'disable_rollback': 'false'})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
def test_disable_rollback_extract_bad(self):
self.assertRaises(ValueError, util.extract_args,
{'disable_rollback': 'bad'})
class TestExtractBool(HeatTestCase):
def test_extract_bool(self):
for value in ('True', 'true', 'TRUE', True):
self.assertTrue(util.extract_bool(value))
for value in ('False', 'false', 'FALSE', False):
self.assertFalse(util.extract_bool(value))
for value in ('foo', 't', 'f', 'yes', 'no', 'y', 'n', '1', '0', None):
self.assertRaises(ValueError, util.extract_bool, value)

View File

@ -0,0 +1,25 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from heat.common import param_utils
from heat.tests.common import HeatTestCase
class TestExtractBool(HeatTestCase):
def test_extract_bool(self):
for value in ('True', 'true', 'TRUE', True):
self.assertTrue(param_utils.extract_bool(value))
for value in ('False', 'false', 'FALSE', False):
self.assertFalse(param_utils.extract_bool(value))
for value in ('foo', 't', 'f', 'yes', 'no', 'y', 'n', '1', '0', None):
self.assertRaises(ValueError, param_utils.extract_bool, value)

View File

@ -14,7 +14,9 @@
from datetime import datetime from datetime import datetime
import uuid import uuid
import json
import mock import mock
import six
from heat.common.identifier import EventIdentifier from heat.common.identifier import EventIdentifier
from heat.common import template_format from heat.common import template_format
@ -790,3 +792,80 @@ class FormatSoftwareConfigDeploymentTest(HeatTestCase):
def test_format_software_deployment_none(self): def test_format_software_deployment_none(self):
self.assertIsNone(api.format_software_deployment(None)) self.assertIsNone(api.format_software_deployment(None))
class TestExtractArgs(HeatTestCase):
def test_timeout_extract(self):
p = {'timeout_mins': '5'}
args = api.extract_args(p)
self.assertEqual(5, args['timeout_mins'])
def test_timeout_extract_zero(self):
p = {'timeout_mins': '0'}
args = api.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_garbage(self):
p = {'timeout_mins': 'wibble'}
args = api.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_none(self):
p = {'timeout_mins': None}
args = api.extract_args(p)
self.assertNotIn('timeout_mins', args)
def test_timeout_extract_negative(self):
p = {'timeout_mins': '-100'}
error = self.assertRaises(ValueError, api.extract_args, p)
self.assertIn('Invalid timeout value', six.text_type(error))
def test_timeout_extract_not_present(self):
args = api.extract_args({})
self.assertNotIn('timeout_mins', args)
def test_adopt_stack_data_extract_present(self):
p = {'adopt_stack_data': json.dumps({'Resources': {}})}
args = api.extract_args(p)
self.assertTrue(args.get('adopt_stack_data'))
def test_invalid_adopt_stack_data(self):
p = {'adopt_stack_data': json.dumps("foo")}
error = self.assertRaises(ValueError, api.extract_args, p)
self.assertEqual(
'Unexpected adopt data "foo". Adopt data must be a dict.',
six.text_type(error))
def test_adopt_stack_data_extract_not_present(self):
args = api.extract_args({})
self.assertNotIn('adopt_stack_data', args)
def test_disable_rollback_extract_true(self):
args = api.extract_args({'disable_rollback': True})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
args = api.extract_args({'disable_rollback': 'True'})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
args = api.extract_args({'disable_rollback': 'true'})
self.assertIn('disable_rollback', args)
self.assertTrue(args.get('disable_rollback'))
def test_disable_rollback_extract_false(self):
args = api.extract_args({'disable_rollback': False})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
args = api.extract_args({'disable_rollback': 'False'})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
args = api.extract_args({'disable_rollback': 'false'})
self.assertIn('disable_rollback', args)
self.assertFalse(args.get('disable_rollback'))
def test_disable_rollback_extract_bad(self):
self.assertRaises(ValueError, api.extract_args,
{'disable_rollback': 'bad'})