Add the ability to validate default options value

Currently, the default options value provided are not validated, which
can trigger fun things like a IntOpt having a "foobar" as default value,
with code failing later when the option is being used.

This patch introduces a _validate_value method that can be used to
validate the default value to the expected type.

ConfigOpts.__call__ gains a validate_default_values arguments that can
trigger default value validation. The argument is False by default in
order to not break existing code.

Ultimately, we could in a later version validate the default at Opt
object creation time, which would be much better, but that may break
existing code, so we'll see.

Change-Id: I6c1998a3be3f1579139317cd85d56b42cabf89db
This commit is contained in:
Julien Danjou 2013-07-22 16:12:51 +02:00
parent b4252b9217
commit a289afed93
2 changed files with 91 additions and 9 deletions

View File

@ -533,6 +533,7 @@ class Opt(object):
""" """
multi = False multi = False
_convert_value = None _convert_value = None
_validate_value = None
def __init__(self, name, dest=None, short=None, default=None, def __init__(self, name, dest=None, short=None, default=None,
positional=False, metavar=None, help=None, positional=False, metavar=None, help=None,
@ -595,7 +596,8 @@ class Opt(object):
return namespace._get_value(names, return namespace._get_value(names,
self.multi, self.positional, self.multi, self.positional,
self._convert_value) self._convert_value,
self._validate_value)
def _add_to_cli(self, parser, group=None): def _add_to_cli(self, parser, group=None):
"""Makes the option available in the command line interface. """Makes the option available in the command line interface.
@ -783,13 +785,12 @@ class StrOpt(Opt):
choices=self.choices, choices=self.choices,
**kwargs) **kwargs)
def _convert_value(self, value): def _validate_value(self, value):
"""Validate a value, no actual conversion required.""" """Validate a value, no actual conversion required."""
if self.choices and value not in self.choices: if self.choices and value not in self.choices:
message = ('Invalid value: %r (choose from %s)' % message = ('Invalid value: %r (choose from %s)' %
(value, ', '.join(map(repr, self.choices)))) (value, ', '.join(map(repr, self.choices))))
raise ValueError(message) raise ValueError(message)
return value
class BoolOpt(Opt): class BoolOpt(Opt):
@ -811,6 +812,11 @@ class BoolOpt(Opt):
raise ValueError('positional boolean args not supported') raise ValueError('positional boolean args not supported')
super(BoolOpt, self).__init__(*args, **kwargs) super(BoolOpt, self).__init__(*args, **kwargs)
@staticmethod
def _validate_value(value):
if not isinstance(value, bool):
raise ValueError("Value is not a bool")
@staticmethod @staticmethod
def _convert_value(value): def _convert_value(value):
"""Convert a string value to a bool.""" """Convert a string value to a bool."""
@ -857,6 +863,11 @@ class IntOpt(Opt):
_convert_value = int _convert_value = int
@staticmethod
def _validate_value(value):
if not isinstance(value, int):
raise ValueError("Value is not an int")
def _get_argparse_kwargs(self, group, **kwargs): def _get_argparse_kwargs(self, group, **kwargs):
"""Extends the base argparse keyword dict for integer options.""" """Extends the base argparse keyword dict for integer options."""
return super(IntOpt, return super(IntOpt,
@ -869,6 +880,11 @@ class FloatOpt(Opt):
_convert_value = float _convert_value = float
@staticmethod
def _validate_value(value):
if not isinstance(value, float):
raise ValueError("Value is not an float")
def _get_argparse_kwargs(self, group, **kwargs): def _get_argparse_kwargs(self, group, **kwargs):
"""Extends the base argparse keyword dict for float options.""" """Extends the base argparse keyword dict for float options."""
return super(FloatOpt, self)._get_argparse_kwargs(group, return super(FloatOpt, self)._get_argparse_kwargs(group,
@ -883,6 +899,8 @@ class ListOpt(Opt):
is a list containing these strings. is a list containing these strings.
""" """
_validate_value = iter
@staticmethod @staticmethod
def _convert_value(value): def _convert_value(value):
"""Convert a string value to a list.""" """Convert a string value to a list."""
@ -911,6 +929,13 @@ class DictOpt(Opt):
value is a dictionary of these key/value pairs value is a dictionary of these key/value pairs
""" """
@staticmethod
def _validate_value(value):
try:
return dict(value)
except TypeError as e:
raise ValueError("Unable to convert value to a dict: %s" % str(e))
@staticmethod @staticmethod
def _convert_value(value): def _convert_value(value):
"""Split a line. """Split a line.
@ -1293,13 +1318,15 @@ class MultiConfigParser(object):
def get(self, names, multi=False): def get(self, names, multi=False):
return self._get(names, multi=multi) return self._get(names, multi=multi)
def _get(self, names, multi=False, normalized=False, convert_value=None): def _get(self, names, multi=False, normalized=False, convert_value=None,
validate_value=None):
"""Fetch a config file value from the parsed files. """Fetch a config file value from the parsed files.
:param names: a list of (section, name) tuples :param names: a list of (section, name) tuples
:param multi: a boolean indicating whether to return multiple values :param multi: a boolean indicating whether to return multiple values
:param normalized: whether to normalize group names to lowercase :param normalized: whether to normalize group names to lowercase
:param convert_value: callable to convert a string into the proper type :param convert_value: callable to convert a string into the proper type
:param validate_value: callable to validate value
""" """
rvalue = [] rvalue = []
@ -1317,6 +1344,9 @@ class MultiConfigParser(object):
continue continue
if name in sections[section]: if name in sections[section]:
val = [convert(v) for v in sections[section][name]] val = [convert(v) for v in sections[section][name]]
if validate_value:
for v in val:
validate_value(v)
if multi: if multi:
rvalue = val + rvalue rvalue = val + rvalue
else: else:
@ -1435,7 +1465,8 @@ class _Namespace(argparse.Namespace):
pass pass
raise KeyError raise KeyError
def _get_value(self, names, multi, positional, convert_value): def _get_value(self, names, multi, positional, convert_value,
validate_value):
"""Fetch a value from config files. """Fetch a value from config files.
Multiple names for a given configuration option may be supplied so Multiple names for a given configuration option may be supplied so
@ -1446,6 +1477,7 @@ class _Namespace(argparse.Namespace):
:param multi: a boolean indicating whether to return multiple values :param multi: a boolean indicating whether to return multiple values
:param positional: whether this is a positional option :param positional: whether this is a positional option
:param convert_value: callable to convert a string into the proper type :param convert_value: callable to convert a string into the proper type
:param validate_value: callable to validate the converted value
""" """
try: try:
return self._get_cli_value(names, positional) return self._get_cli_value(names, positional)
@ -1454,7 +1486,8 @@ class _Namespace(argparse.Namespace):
names = [(g if g is not None else 'DEFAULT', n) for g, n in names] names = [(g if g is not None else 'DEFAULT', n) for g, n in names]
values = self.parser._get(names, multi=multi, normalized=True, values = self.parser._get(names, multi=multi, normalized=True,
convert_value=convert_value) convert_value=convert_value,
validate_value=validate_value)
return values if multi else values[-1] return values if multi else values[-1]
@ -1524,6 +1557,7 @@ class ConfigOpts(collections.Mapping):
self._namespace = None self._namespace = None
self.__cache = {} self.__cache = {}
self._config_opts = [] self._config_opts = []
self._validate_default_values = False
def _pre_setup(self, project, prog, version, usage, default_config_files): def _pre_setup(self, project, prog, version, usage, default_config_files):
"""Initialize a ConfigCliParser object for option parsing.""" """Initialize a ConfigCliParser object for option parsing."""
@ -1590,7 +1624,8 @@ class ConfigOpts(collections.Mapping):
prog=None, prog=None,
version=None, version=None,
usage=None, usage=None,
default_config_files=None): default_config_files=None,
validate_default_values=False):
"""Parse command line arguments and config files. """Parse command line arguments and config files.
Calling a ConfigOpts object causes the supplied command line arguments Calling a ConfigOpts object causes the supplied command line arguments
@ -1613,13 +1648,15 @@ class ConfigOpts(collections.Mapping):
:param version: the program version (for --version) :param version: the program version (for --version)
:param usage: a usage string (%prog will be expanded) :param usage: a usage string (%prog will be expanded)
:param default_config_files: config files to use by default :param default_config_files: config files to use by default
:param validate_default_values: whether to validate the default values
:returns: the list of arguments left over after parsing options :returns: the list of arguments left over after parsing options
:raises: SystemExit, ConfigFilesNotFoundError, ConfigFileParseError, :raises: SystemExit, ConfigFilesNotFoundError, ConfigFileParseError,
RequiredOptError, DuplicateOptError RequiredOptError, DuplicateOptError
""" """
self.clear() self.clear()
self._validate_default_values = validate_default_values
prog, default_config_files = self._pre_setup(project, prog, default_config_files = self._pre_setup(project,
prog, prog,
version, version,
@ -1679,6 +1716,7 @@ class ConfigOpts(collections.Mapping):
self._args = None self._args = None
self._oparser = None self._oparser = None
self._namespace = None self._namespace = None
self._validate_default_values = False
self.unregister_opts(self._config_opts) self.unregister_opts(self._config_opts)
for group in self._groups.values(): for group in self._groups.values():
group._clear() group._clear()
@ -2030,6 +2068,15 @@ class ConfigOpts(collections.Mapping):
if 'default' in info: if 'default' in info:
return info['default'] return info['default']
if self._validate_default_values:
if opt._validate_value and opt.default is not None:
try:
opt._validate_value(opt.default)
except ValueError as e:
raise ConfigFileValueError(
"Default value for option %s is not valid: %s"
% (opt.name, str(e)))
return opt.default return opt.default
def _substitute(self, value): def _substitute(self, value):

View File

@ -88,7 +88,8 @@ class BaseTestCase(utils.BaseTestCase):
prog='test', prog='test',
version='1.0', version='1.0',
usage='%(prog)s FOO BAR', usage='%(prog)s FOO BAR',
default_config_files=default_config_files) default_config_files=default_config_files,
validate_default_values=True)
def setUp(self): def setUp(self):
super(BaseTestCase, self).setUp() super(BaseTestCase, self).setUp()
@ -790,6 +791,18 @@ class ConfigFileOptsTestCase(BaseTestCase):
self.assertTrue(hasattr(self.conf, 'foo')) self.assertTrue(hasattr(self.conf, 'foo'))
self.assertEqual(self.conf.foo, 666) self.assertEqual(self.conf.foo, 666)
def test_conf_file_int_wrong_default(self):
self.conf.register_opt(cfg.IntOpt('foo', default='t666'))
paths = self.create_tempfiles([('test',
'[DEFAULT]\n')])
self.conf(['--config-file', paths[0]])
self.assertRaises(AttributeError,
getattr,
self.conf,
'foo')
def test_conf_file_int_value(self): def test_conf_file_int_value(self):
self.conf.register_opt(cfg.IntOpt('foo')) self.conf.register_opt(cfg.IntOpt('foo'))
@ -851,6 +864,18 @@ class ConfigFileOptsTestCase(BaseTestCase):
self.assertTrue(hasattr(self.conf, 'foo')) self.assertTrue(hasattr(self.conf, 'foo'))
self.assertEqual(self.conf.foo, 6.66) self.assertEqual(self.conf.foo, 6.66)
def test_conf_file_float_default_wrong_type(self):
self.conf.register_opt(cfg.FloatOpt('foo', default='foobar6.66'))
paths = self.create_tempfiles([('test',
'[DEFAULT]\n')])
self.conf(['--config-file', paths[0]])
self.assertRaises(AttributeError,
getattr,
self.conf,
'foo')
def test_conf_file_float_value(self): def test_conf_file_float_value(self):
self.conf.register_opt(cfg.FloatOpt('foo')) self.conf.register_opt(cfg.FloatOpt('foo'))
@ -2909,6 +2934,16 @@ class ChoicesTestCase(BaseTestCase):
self.assertTrue(hasattr(self.conf, 'foo')) self.assertTrue(hasattr(self.conf, 'foo'))
self.assertEqual(self.conf.foo, 'baaar') self.assertEqual(self.conf.foo, 'baaar')
def test_conf_file_choice_bad_default(self):
self.conf.register_cli_opt(cfg.StrOpt('foo',
choices=['baar', 'baaar'],
default='foobaz'))
self.conf([])
self.assertRaises(AttributeError,
getattr,
self.conf,
'foobaz')
class PrintHelpTestCase(utils.BaseTestCase): class PrintHelpTestCase(utils.BaseTestCase):