Merge "Start switching from gflags to optparse"
This commit is contained in:
361
nova/flags.py
361
nova/flags.py
@@ -3,6 +3,7 @@
|
||||
# Copyright 2010 United States Government as represented by the
|
||||
# Administrator of the National Aeronautics and Space Administration.
|
||||
# All Rights Reserved.
|
||||
# Copyright 2011 Red Hat, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
@@ -18,14 +19,14 @@
|
||||
|
||||
"""Command-line flag library.
|
||||
|
||||
Wraps gflags.
|
||||
Emulates gflags by wrapping optparse.
|
||||
|
||||
Package-level global flags are defined here, the rest are defined
|
||||
where they're used.
|
||||
The idea is to move to optparse eventually, and this wrapper is a
|
||||
stepping stone.
|
||||
|
||||
"""
|
||||
|
||||
import getopt
|
||||
import optparse
|
||||
import os
|
||||
import socket
|
||||
import string
|
||||
@@ -34,120 +35,165 @@ import sys
|
||||
import gflags
|
||||
|
||||
|
||||
class FlagValues(gflags.FlagValues):
|
||||
"""Extension of gflags.FlagValues that allows undefined and runtime flags.
|
||||
class FlagValues(object):
|
||||
class Flag:
|
||||
def __init__(self, name, value, update_default=None):
|
||||
self.name = name
|
||||
self.value = value
|
||||
self._update_default = update_default
|
||||
|
||||
Unknown flags will be ignored when parsing the command line, but the
|
||||
command line will be kept so that it can be replayed if new flags are
|
||||
defined after the initial parsing.
|
||||
def SetDefault(self, default):
|
||||
if self._update_default:
|
||||
self._update_default(self.name, default)
|
||||
|
||||
"""
|
||||
class ErrorCatcher:
|
||||
def __init__(self, orig_error):
|
||||
self.orig_error = orig_error
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self._error_msg = None
|
||||
|
||||
def catch(self, msg):
|
||||
if ": --" in msg:
|
||||
self._error_msg = msg
|
||||
else:
|
||||
self.orig_error(msg)
|
||||
|
||||
def get_unknown_arg(self, args):
|
||||
if not self._error_msg:
|
||||
return None
|
||||
# Error message is e.g. "no such option: --runtime_answer"
|
||||
a = self._error_msg[self._error_msg.rindex(": --") + 2:]
|
||||
return filter(lambda i: i == a or i.startswith(a + "="), args)[0]
|
||||
|
||||
def __init__(self, extra_context=None):
|
||||
gflags.FlagValues.__init__(self)
|
||||
self.__dict__['__dirty'] = []
|
||||
self.__dict__['__was_already_parsed'] = False
|
||||
self.__dict__['__stored_argv'] = []
|
||||
self.__dict__['__extra_context'] = extra_context
|
||||
self._parser = optparse.OptionParser()
|
||||
self._parser.disable_interspersed_args()
|
||||
self._extra_context = extra_context
|
||||
self.Reset()
|
||||
|
||||
def _parse(self):
|
||||
if not self._values is None:
|
||||
return
|
||||
|
||||
args = gflags.FlagValues().ReadFlagsFromFiles(self._args)
|
||||
|
||||
values = extra = None
|
||||
|
||||
#
|
||||
# This horrendous hack allows us to stop optparse
|
||||
# exiting when it encounters an unknown option
|
||||
#
|
||||
error_catcher = self.ErrorCatcher(self._parser.error)
|
||||
self._parser.error = error_catcher.catch
|
||||
try:
|
||||
while True:
|
||||
error_catcher.reset()
|
||||
|
||||
(values, extra) = self._parser.parse_args(args)
|
||||
|
||||
unknown = error_catcher.get_unknown_arg(args)
|
||||
if not unknown:
|
||||
break
|
||||
|
||||
args.remove(unknown)
|
||||
finally:
|
||||
self._parser.error = error_catcher.orig_error
|
||||
|
||||
(self._values, self._extra) = (values, extra)
|
||||
|
||||
def __call__(self, argv):
|
||||
# We're doing some hacky stuff here so that we don't have to copy
|
||||
# out all the code of the original verbatim and then tweak a few lines.
|
||||
# We're hijacking the output of getopt so we can still return the
|
||||
# leftover args at the end
|
||||
sneaky_unparsed_args = {"value": None}
|
||||
original_argv = list(argv)
|
||||
|
||||
if self.IsGnuGetOpt():
|
||||
orig_getopt = getattr(getopt, 'gnu_getopt')
|
||||
orig_name = 'gnu_getopt'
|
||||
else:
|
||||
orig_getopt = getattr(getopt, 'getopt')
|
||||
orig_name = 'getopt'
|
||||
|
||||
def _sneaky(*args, **kw):
|
||||
optlist, unparsed_args = orig_getopt(*args, **kw)
|
||||
sneaky_unparsed_args['value'] = unparsed_args
|
||||
return optlist, unparsed_args
|
||||
|
||||
try:
|
||||
setattr(getopt, orig_name, _sneaky)
|
||||
args = gflags.FlagValues.__call__(self, argv)
|
||||
except gflags.UnrecognizedFlagError:
|
||||
# Undefined args were found, for now we don't care so just
|
||||
# act like everything went well
|
||||
# (these three lines are copied pretty much verbatim from the end
|
||||
# of the __call__ function we are wrapping)
|
||||
unparsed_args = sneaky_unparsed_args['value']
|
||||
if unparsed_args:
|
||||
if self.IsGnuGetOpt():
|
||||
args = argv[:1] + unparsed_args
|
||||
else:
|
||||
args = argv[:1] + original_argv[-len(unparsed_args):]
|
||||
else:
|
||||
args = argv[:1]
|
||||
finally:
|
||||
setattr(getopt, orig_name, orig_getopt)
|
||||
|
||||
# Store the arguments for later, we'll need them for new flags
|
||||
# added at runtime
|
||||
self.__dict__['__stored_argv'] = original_argv
|
||||
self.__dict__['__was_already_parsed'] = True
|
||||
self.ClearDirty()
|
||||
return args
|
||||
|
||||
def Reset(self):
|
||||
gflags.FlagValues.Reset(self)
|
||||
self.__dict__['__dirty'] = []
|
||||
self.__dict__['__was_already_parsed'] = False
|
||||
self.__dict__['__stored_argv'] = []
|
||||
|
||||
def SetDirty(self, name):
|
||||
"""Mark a flag as dirty so that accessing it will case a reparse."""
|
||||
self.__dict__['__dirty'].append(name)
|
||||
|
||||
def IsDirty(self, name):
|
||||
return name in self.__dict__['__dirty']
|
||||
|
||||
def ClearDirty(self):
|
||||
self.__dict__['__dirty'] = []
|
||||
|
||||
def WasAlreadyParsed(self):
|
||||
return self.__dict__['__was_already_parsed']
|
||||
|
||||
def ParseNewFlags(self):
|
||||
if '__stored_argv' not in self.__dict__:
|
||||
return
|
||||
new_flags = FlagValues(self)
|
||||
for k in self.FlagDict().iterkeys():
|
||||
new_flags[k] = gflags.FlagValues.__getitem__(self, k)
|
||||
|
||||
new_flags.Reset()
|
||||
new_flags(self.__dict__['__stored_argv'])
|
||||
for k in new_flags.FlagDict().iterkeys():
|
||||
setattr(self, k, getattr(new_flags, k))
|
||||
self.ClearDirty()
|
||||
|
||||
def __setitem__(self, name, flag):
|
||||
gflags.FlagValues.__setitem__(self, name, flag)
|
||||
if self.WasAlreadyParsed():
|
||||
self.SetDirty(name)
|
||||
|
||||
def __getitem__(self, name):
|
||||
if self.IsDirty(name):
|
||||
self.ParseNewFlags()
|
||||
return gflags.FlagValues.__getitem__(self, name)
|
||||
self._args = argv[1:]
|
||||
self._values = None
|
||||
self._parse()
|
||||
return [argv[0]] + self._extra
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self.IsDirty(name):
|
||||
self.ParseNewFlags()
|
||||
val = gflags.FlagValues.__getattr__(self, name)
|
||||
self._parse()
|
||||
val = getattr(self._values, name)
|
||||
if type(val) is str:
|
||||
tmpl = string.Template(val)
|
||||
context = [self, self.__dict__['__extra_context']]
|
||||
context = [self, self._extra_context]
|
||||
return tmpl.substitute(StrWrapper(context))
|
||||
return val
|
||||
|
||||
def get(self, name, default):
|
||||
value = getattr(self, name)
|
||||
if value is not None: # value might be '0' or ""
|
||||
return value
|
||||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, name):
|
||||
self._parse()
|
||||
return hasattr(self._values, name)
|
||||
|
||||
def _update_default(self, name, default):
|
||||
self._parser.set_default(name, default)
|
||||
self._values = None
|
||||
|
||||
def __iter__(self):
|
||||
return self.FlagValuesDict().iterkeys()
|
||||
|
||||
def __getitem__(self, name):
|
||||
self._parse()
|
||||
if not self.__contains__(name):
|
||||
return None
|
||||
return self.Flag(name, getattr(self, name), self._update_default)
|
||||
|
||||
def Reset(self):
|
||||
self._args = []
|
||||
self._values = None
|
||||
self._extra = None
|
||||
|
||||
def ParseNewFlags(self):
|
||||
pass
|
||||
|
||||
def FlagValuesDict(self):
|
||||
ret = {}
|
||||
for opt in self._parser.option_list:
|
||||
if opt.dest:
|
||||
ret[opt.dest] = getattr(self, opt.dest)
|
||||
return ret
|
||||
|
||||
def _add_option(self, name, default, help, prefix='--', **kwargs):
|
||||
prefixed_name = prefix + name
|
||||
for opt in self._parser.option_list:
|
||||
if prefixed_name == opt.get_opt_string():
|
||||
return
|
||||
self._parser.add_option(prefixed_name, dest=name,
|
||||
default=default, help=help, **kwargs)
|
||||
self._values = None
|
||||
|
||||
def define_string(self, name, default, help):
|
||||
self._add_option(name, default, help)
|
||||
|
||||
def define_integer(self, name, default, help):
|
||||
self._add_option(name, default, help, type='int')
|
||||
|
||||
def define_float(self, name, default, help):
|
||||
self._add_option(name, default, help, type='float')
|
||||
|
||||
def define_bool(self, name, default, help):
|
||||
#
|
||||
# FIXME(markmc): this doesn't support --boolflag=true/false/t/f/1/0
|
||||
#
|
||||
self._add_option(name, default, help, action='store_true')
|
||||
self._add_option(name, default, help,
|
||||
prefix="--no", action='store_false')
|
||||
|
||||
def define_list(self, name, default, help):
|
||||
def parse_list(option, opt, value, parser):
|
||||
setattr(self._parser.values, name, value.split(','))
|
||||
self._add_option(name, default, help, type='string',
|
||||
action='callback', callback=parse_list)
|
||||
|
||||
def define_multistring(self, name, default, help):
|
||||
self._add_option(name, default, help, action='append')
|
||||
|
||||
FLAGS = FlagValues()
|
||||
|
||||
|
||||
class StrWrapper(object):
|
||||
"""Wrapper around FlagValues objects.
|
||||
@@ -167,85 +213,60 @@ class StrWrapper(object):
|
||||
raise KeyError(name)
|
||||
|
||||
|
||||
# Copied from gflags with small mods to get the naming correct.
|
||||
# Originally gflags checks for the first module that is not gflags that is
|
||||
# in the call chain, we want to check for the first module that is not gflags
|
||||
# and not this module.
|
||||
def _GetCallingModule():
|
||||
"""Returns the name of the module that's calling into this module.
|
||||
|
||||
We generally use this function to get the name of the module calling a
|
||||
DEFINE_foo... function.
|
||||
|
||||
"""
|
||||
# Walk down the stack to find the first globals dict that's not ours.
|
||||
for depth in range(1, sys.getrecursionlimit()):
|
||||
if not sys._getframe(depth).f_globals is globals():
|
||||
module_name = __GetModuleName(sys._getframe(depth).f_globals)
|
||||
if module_name == 'gflags':
|
||||
continue
|
||||
if module_name is not None:
|
||||
return module_name
|
||||
raise AssertionError("No module was found")
|
||||
def DEFINE_string(name, default, help, flag_values=FLAGS):
|
||||
flag_values.define_string(name, default, help)
|
||||
|
||||
|
||||
# Copied from gflags because it is a private function
|
||||
def __GetModuleName(globals_dict):
|
||||
"""Given a globals dict, returns the name of the module that defines it.
|
||||
|
||||
Args:
|
||||
globals_dict: A dictionary that should correspond to an environment
|
||||
providing the values of the globals.
|
||||
|
||||
Returns:
|
||||
A string (the name of the module) or None (if the module could not
|
||||
be identified.
|
||||
|
||||
"""
|
||||
for name, module in sys.modules.iteritems():
|
||||
if getattr(module, '__dict__', None) is globals_dict:
|
||||
if name == '__main__':
|
||||
return sys.argv[0]
|
||||
return name
|
||||
return None
|
||||
def DEFINE_integer(name, default, help, lower_bound=None, flag_values=FLAGS):
|
||||
# FIXME(markmc): ignoring lower_bound
|
||||
flag_values.define_integer(name, default, help)
|
||||
|
||||
|
||||
def _wrapper(func):
|
||||
def _wrapped(*args, **kw):
|
||||
kw.setdefault('flag_values', FLAGS)
|
||||
func(*args, **kw)
|
||||
_wrapped.func_name = func.func_name
|
||||
return _wrapped
|
||||
def DEFINE_bool(name, default, help, flag_values=FLAGS):
|
||||
flag_values.define_bool(name, default, help)
|
||||
|
||||
|
||||
FLAGS = FlagValues()
|
||||
gflags.FLAGS = FLAGS
|
||||
gflags._GetCallingModule = _GetCallingModule
|
||||
def DEFINE_boolean(name, default, help, flag_values=FLAGS):
|
||||
DEFINE_bool(name, default, help, flag_values)
|
||||
|
||||
|
||||
DEFINE = _wrapper(gflags.DEFINE)
|
||||
DEFINE_string = _wrapper(gflags.DEFINE_string)
|
||||
DEFINE_integer = _wrapper(gflags.DEFINE_integer)
|
||||
DEFINE_bool = _wrapper(gflags.DEFINE_bool)
|
||||
DEFINE_boolean = _wrapper(gflags.DEFINE_boolean)
|
||||
DEFINE_float = _wrapper(gflags.DEFINE_float)
|
||||
DEFINE_enum = _wrapper(gflags.DEFINE_enum)
|
||||
DEFINE_list = _wrapper(gflags.DEFINE_list)
|
||||
DEFINE_spaceseplist = _wrapper(gflags.DEFINE_spaceseplist)
|
||||
DEFINE_multistring = _wrapper(gflags.DEFINE_multistring)
|
||||
DEFINE_multi_int = _wrapper(gflags.DEFINE_multi_int)
|
||||
DEFINE_flag = _wrapper(gflags.DEFINE_flag)
|
||||
HelpFlag = gflags.HelpFlag
|
||||
HelpshortFlag = gflags.HelpshortFlag
|
||||
HelpXMLFlag = gflags.HelpXMLFlag
|
||||
def DEFINE_list(name, default, help, flag_values=FLAGS):
|
||||
flag_values.define_list(name, default, help)
|
||||
|
||||
|
||||
def DEFINE_float(name, default, help, flag_values=FLAGS):
|
||||
flag_values.define_float(name, default, help)
|
||||
|
||||
|
||||
def DEFINE_multistring(name, default, help, flag_values=FLAGS):
|
||||
flag_values.define_multistring(name, default, help)
|
||||
|
||||
|
||||
class UnrecognizedFlag(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def DECLARE(name, module_string, flag_values=FLAGS):
|
||||
if module_string not in sys.modules:
|
||||
__import__(module_string, globals(), locals())
|
||||
if name not in flag_values:
|
||||
raise gflags.UnrecognizedFlag(
|
||||
"%s not defined by %s" % (name, module_string))
|
||||
raise UnrecognizedFlag('%s not defined by %s' % (name, module_string))
|
||||
|
||||
|
||||
def DEFINE_flag(flag):
|
||||
pass
|
||||
|
||||
|
||||
class HelpFlag:
|
||||
pass
|
||||
|
||||
|
||||
class HelpshortFlag:
|
||||
pass
|
||||
|
||||
|
||||
class HelpXMLFlag:
|
||||
pass
|
||||
|
||||
|
||||
def _get_my_ip():
|
||||
|
@@ -3,6 +3,7 @@
|
||||
# Copyright 2010 United States Government as represented by the
|
||||
# Administrator of the National Aeronautics and Space Administration.
|
||||
# All Rights Reserved.
|
||||
# Copyright 2011 Red Hat, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
@@ -16,6 +17,10 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import exceptions
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from nova import exception
|
||||
from nova import flags
|
||||
from nova import test
|
||||
@@ -64,6 +69,37 @@ class FlagsTestCase(test.TestCase):
|
||||
self.assertEqual(self.FLAGS.false, True)
|
||||
self.assertEqual(self.FLAGS.true, False)
|
||||
|
||||
def test_define_float(self):
|
||||
flags.DEFINE_float('float', 6.66, 'desc', flag_values=self.FLAGS)
|
||||
self.assertEqual(self.FLAGS.float, 6.66)
|
||||
|
||||
def test_define_multistring(self):
|
||||
flags.DEFINE_multistring('multi', [], 'desc', flag_values=self.FLAGS)
|
||||
|
||||
argv = ['flags_test', '--multi', 'foo', '--multi', 'bar']
|
||||
self.FLAGS(argv)
|
||||
|
||||
self.assertEqual(self.FLAGS.multi, ['foo', 'bar'])
|
||||
|
||||
def test_define_list(self):
|
||||
flags.DEFINE_list('list', ['foo'], 'desc', flag_values=self.FLAGS)
|
||||
|
||||
self.assert_(self.FLAGS['list'])
|
||||
self.assertEqual(self.FLAGS.list, ['foo'])
|
||||
|
||||
argv = ['flags_test', '--list=a,b,c,d']
|
||||
self.FLAGS(argv)
|
||||
|
||||
self.assertEqual(self.FLAGS.list, ['a', 'b', 'c', 'd'])
|
||||
|
||||
def test_error(self):
|
||||
flags.DEFINE_integer('error', 1, 'desc', flag_values=self.FLAGS)
|
||||
|
||||
self.assertEqual(self.FLAGS.error, 1)
|
||||
|
||||
argv = ['flags_test', '--error=foo']
|
||||
self.assertRaises(exceptions.SystemExit, self.FLAGS, argv)
|
||||
|
||||
def test_declare(self):
|
||||
self.assert_('answer' not in self.global_FLAGS)
|
||||
flags.DECLARE('answer', 'nova.tests.declare_flags')
|
||||
@@ -76,6 +112,14 @@ class FlagsTestCase(test.TestCase):
|
||||
flags.DECLARE('answer', 'nova.tests.declare_flags')
|
||||
self.assertEqual(self.global_FLAGS.answer, 256)
|
||||
|
||||
def test_getopt_non_interspersed_args(self):
|
||||
self.assert_('runtime_answer' not in self.global_FLAGS)
|
||||
|
||||
argv = ['flags_test', 'extra_arg', '--runtime_answer=60']
|
||||
args = self.global_FLAGS(argv)
|
||||
self.assertEqual(len(args), 3)
|
||||
self.assertEqual(argv, args)
|
||||
|
||||
def test_runtime_and_unknown_flags(self):
|
||||
self.assert_('runtime_answer' not in self.global_FLAGS)
|
||||
|
||||
@@ -114,3 +158,49 @@ class FlagsTestCase(test.TestCase):
|
||||
self.assertEqual(FLAGS.flags_unittest, 'foo')
|
||||
FLAGS.flags_unittest = 'bar'
|
||||
self.assertEqual(FLAGS.flags_unittest, 'bar')
|
||||
|
||||
def test_flag_overrides(self):
|
||||
self.assertEqual(FLAGS.flags_unittest, 'foo')
|
||||
self.flags(flags_unittest='bar')
|
||||
self.assertEqual(FLAGS.flags_unittest, 'bar')
|
||||
self.assertEqual(FLAGS['flags_unittest'].value, 'bar')
|
||||
self.assertEqual(FLAGS.FlagValuesDict()['flags_unittest'], 'bar')
|
||||
self.reset_flags()
|
||||
self.assertEqual(FLAGS.flags_unittest, 'foo')
|
||||
self.assertEqual(FLAGS['flags_unittest'].value, 'foo')
|
||||
self.assertEqual(FLAGS.FlagValuesDict()['flags_unittest'], 'foo')
|
||||
|
||||
def test_flagfile(self):
|
||||
flags.DEFINE_string('string', 'default', 'desc',
|
||||
flag_values=self.FLAGS)
|
||||
flags.DEFINE_integer('int', 1, 'desc', flag_values=self.FLAGS)
|
||||
flags.DEFINE_bool('false', False, 'desc', flag_values=self.FLAGS)
|
||||
flags.DEFINE_bool('true', True, 'desc', flag_values=self.FLAGS)
|
||||
|
||||
(fd, path) = tempfile.mkstemp(prefix='nova', suffix='.flags')
|
||||
|
||||
try:
|
||||
os.write(fd, '--string=foo\n--int=2\n--false\n--notrue\n')
|
||||
os.close(fd)
|
||||
|
||||
self.FLAGS(['flags_test', '--flagfile=' + path])
|
||||
|
||||
self.assertEqual(self.FLAGS.string, 'foo')
|
||||
self.assertEqual(self.FLAGS.int, 2)
|
||||
self.assertEqual(self.FLAGS.false, True)
|
||||
self.assertEqual(self.FLAGS.true, False)
|
||||
finally:
|
||||
os.remove(path)
|
||||
|
||||
def test_defaults(self):
|
||||
flags.DEFINE_string('foo', 'bar', 'help', flag_values=self.FLAGS)
|
||||
self.assertEqual(self.FLAGS.foo, 'bar')
|
||||
|
||||
self.FLAGS['foo'].SetDefault('blaa')
|
||||
self.assertEqual(self.FLAGS.foo, 'blaa')
|
||||
|
||||
def test_templated_values(self):
|
||||
flags.DEFINE_string('foo', 'foo', 'help', flag_values=self.FLAGS)
|
||||
flags.DEFINE_string('bar', 'bar', 'help', flag_values=self.FLAGS)
|
||||
flags.DEFINE_string('blaa', '$foo$bar', 'help', flag_values=self.FLAGS)
|
||||
self.assertEqual(self.FLAGS.blaa, 'foobar')
|
||||
|
Reference in New Issue
Block a user