Merge "Start switching from gflags to optparse"

This commit is contained in:
Jenkins
2011-10-19 20:17:02 +00:00
committed by Gerrit Code Review
2 changed files with 281 additions and 170 deletions

View File

@@ -3,6 +3,7 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
# Copyright 2011 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@@ -18,14 +19,14 @@
"""Command-line flag library.
Wraps gflags.
Emulates gflags by wrapping optparse.
Package-level global flags are defined here, the rest are defined
where they're used.
The idea is to move to optparse eventually, and this wrapper is a
stepping stone.
"""
import getopt
import optparse
import os
import socket
import string
@@ -34,120 +35,165 @@ import sys
import gflags
class FlagValues(gflags.FlagValues):
"""Extension of gflags.FlagValues that allows undefined and runtime flags.
class FlagValues(object):
class Flag:
def __init__(self, name, value, update_default=None):
self.name = name
self.value = value
self._update_default = update_default
Unknown flags will be ignored when parsing the command line, but the
command line will be kept so that it can be replayed if new flags are
defined after the initial parsing.
def SetDefault(self, default):
if self._update_default:
self._update_default(self.name, default)
"""
class ErrorCatcher:
def __init__(self, orig_error):
self.orig_error = orig_error
self.reset()
def reset(self):
self._error_msg = None
def catch(self, msg):
if ": --" in msg:
self._error_msg = msg
else:
self.orig_error(msg)
def get_unknown_arg(self, args):
if not self._error_msg:
return None
# Error message is e.g. "no such option: --runtime_answer"
a = self._error_msg[self._error_msg.rindex(": --") + 2:]
return filter(lambda i: i == a or i.startswith(a + "="), args)[0]
def __init__(self, extra_context=None):
gflags.FlagValues.__init__(self)
self.__dict__['__dirty'] = []
self.__dict__['__was_already_parsed'] = False
self.__dict__['__stored_argv'] = []
self.__dict__['__extra_context'] = extra_context
self._parser = optparse.OptionParser()
self._parser.disable_interspersed_args()
self._extra_context = extra_context
self.Reset()
def _parse(self):
if not self._values is None:
return
args = gflags.FlagValues().ReadFlagsFromFiles(self._args)
values = extra = None
#
# This horrendous hack allows us to stop optparse
# exiting when it encounters an unknown option
#
error_catcher = self.ErrorCatcher(self._parser.error)
self._parser.error = error_catcher.catch
try:
while True:
error_catcher.reset()
(values, extra) = self._parser.parse_args(args)
unknown = error_catcher.get_unknown_arg(args)
if not unknown:
break
args.remove(unknown)
finally:
self._parser.error = error_catcher.orig_error
(self._values, self._extra) = (values, extra)
def __call__(self, argv):
# We're doing some hacky stuff here so that we don't have to copy
# out all the code of the original verbatim and then tweak a few lines.
# We're hijacking the output of getopt so we can still return the
# leftover args at the end
sneaky_unparsed_args = {"value": None}
original_argv = list(argv)
if self.IsGnuGetOpt():
orig_getopt = getattr(getopt, 'gnu_getopt')
orig_name = 'gnu_getopt'
else:
orig_getopt = getattr(getopt, 'getopt')
orig_name = 'getopt'
def _sneaky(*args, **kw):
optlist, unparsed_args = orig_getopt(*args, **kw)
sneaky_unparsed_args['value'] = unparsed_args
return optlist, unparsed_args
try:
setattr(getopt, orig_name, _sneaky)
args = gflags.FlagValues.__call__(self, argv)
except gflags.UnrecognizedFlagError:
# Undefined args were found, for now we don't care so just
# act like everything went well
# (these three lines are copied pretty much verbatim from the end
# of the __call__ function we are wrapping)
unparsed_args = sneaky_unparsed_args['value']
if unparsed_args:
if self.IsGnuGetOpt():
args = argv[:1] + unparsed_args
else:
args = argv[:1] + original_argv[-len(unparsed_args):]
else:
args = argv[:1]
finally:
setattr(getopt, orig_name, orig_getopt)
# Store the arguments for later, we'll need them for new flags
# added at runtime
self.__dict__['__stored_argv'] = original_argv
self.__dict__['__was_already_parsed'] = True
self.ClearDirty()
return args
def Reset(self):
gflags.FlagValues.Reset(self)
self.__dict__['__dirty'] = []
self.__dict__['__was_already_parsed'] = False
self.__dict__['__stored_argv'] = []
def SetDirty(self, name):
"""Mark a flag as dirty so that accessing it will case a reparse."""
self.__dict__['__dirty'].append(name)
def IsDirty(self, name):
return name in self.__dict__['__dirty']
def ClearDirty(self):
self.__dict__['__dirty'] = []
def WasAlreadyParsed(self):
return self.__dict__['__was_already_parsed']
def ParseNewFlags(self):
if '__stored_argv' not in self.__dict__:
return
new_flags = FlagValues(self)
for k in self.FlagDict().iterkeys():
new_flags[k] = gflags.FlagValues.__getitem__(self, k)
new_flags.Reset()
new_flags(self.__dict__['__stored_argv'])
for k in new_flags.FlagDict().iterkeys():
setattr(self, k, getattr(new_flags, k))
self.ClearDirty()
def __setitem__(self, name, flag):
gflags.FlagValues.__setitem__(self, name, flag)
if self.WasAlreadyParsed():
self.SetDirty(name)
def __getitem__(self, name):
if self.IsDirty(name):
self.ParseNewFlags()
return gflags.FlagValues.__getitem__(self, name)
self._args = argv[1:]
self._values = None
self._parse()
return [argv[0]] + self._extra
def __getattr__(self, name):
if self.IsDirty(name):
self.ParseNewFlags()
val = gflags.FlagValues.__getattr__(self, name)
self._parse()
val = getattr(self._values, name)
if type(val) is str:
tmpl = string.Template(val)
context = [self, self.__dict__['__extra_context']]
context = [self, self._extra_context]
return tmpl.substitute(StrWrapper(context))
return val
def get(self, name, default):
value = getattr(self, name)
if value is not None: # value might be '0' or ""
return value
else:
return default
def __contains__(self, name):
self._parse()
return hasattr(self._values, name)
def _update_default(self, name, default):
self._parser.set_default(name, default)
self._values = None
def __iter__(self):
return self.FlagValuesDict().iterkeys()
def __getitem__(self, name):
self._parse()
if not self.__contains__(name):
return None
return self.Flag(name, getattr(self, name), self._update_default)
def Reset(self):
self._args = []
self._values = None
self._extra = None
def ParseNewFlags(self):
pass
def FlagValuesDict(self):
ret = {}
for opt in self._parser.option_list:
if opt.dest:
ret[opt.dest] = getattr(self, opt.dest)
return ret
def _add_option(self, name, default, help, prefix='--', **kwargs):
prefixed_name = prefix + name
for opt in self._parser.option_list:
if prefixed_name == opt.get_opt_string():
return
self._parser.add_option(prefixed_name, dest=name,
default=default, help=help, **kwargs)
self._values = None
def define_string(self, name, default, help):
self._add_option(name, default, help)
def define_integer(self, name, default, help):
self._add_option(name, default, help, type='int')
def define_float(self, name, default, help):
self._add_option(name, default, help, type='float')
def define_bool(self, name, default, help):
#
# FIXME(markmc): this doesn't support --boolflag=true/false/t/f/1/0
#
self._add_option(name, default, help, action='store_true')
self._add_option(name, default, help,
prefix="--no", action='store_false')
def define_list(self, name, default, help):
def parse_list(option, opt, value, parser):
setattr(self._parser.values, name, value.split(','))
self._add_option(name, default, help, type='string',
action='callback', callback=parse_list)
def define_multistring(self, name, default, help):
self._add_option(name, default, help, action='append')
FLAGS = FlagValues()
class StrWrapper(object):
"""Wrapper around FlagValues objects.
@@ -167,85 +213,60 @@ class StrWrapper(object):
raise KeyError(name)
# Copied from gflags with small mods to get the naming correct.
# Originally gflags checks for the first module that is not gflags that is
# in the call chain, we want to check for the first module that is not gflags
# and not this module.
def _GetCallingModule():
"""Returns the name of the module that's calling into this module.
We generally use this function to get the name of the module calling a
DEFINE_foo... function.
"""
# Walk down the stack to find the first globals dict that's not ours.
for depth in range(1, sys.getrecursionlimit()):
if not sys._getframe(depth).f_globals is globals():
module_name = __GetModuleName(sys._getframe(depth).f_globals)
if module_name == 'gflags':
continue
if module_name is not None:
return module_name
raise AssertionError("No module was found")
def DEFINE_string(name, default, help, flag_values=FLAGS):
flag_values.define_string(name, default, help)
# Copied from gflags because it is a private function
def __GetModuleName(globals_dict):
"""Given a globals dict, returns the name of the module that defines it.
Args:
globals_dict: A dictionary that should correspond to an environment
providing the values of the globals.
Returns:
A string (the name of the module) or None (if the module could not
be identified.
"""
for name, module in sys.modules.iteritems():
if getattr(module, '__dict__', None) is globals_dict:
if name == '__main__':
return sys.argv[0]
return name
return None
def DEFINE_integer(name, default, help, lower_bound=None, flag_values=FLAGS):
# FIXME(markmc): ignoring lower_bound
flag_values.define_integer(name, default, help)
def _wrapper(func):
def _wrapped(*args, **kw):
kw.setdefault('flag_values', FLAGS)
func(*args, **kw)
_wrapped.func_name = func.func_name
return _wrapped
def DEFINE_bool(name, default, help, flag_values=FLAGS):
flag_values.define_bool(name, default, help)
FLAGS = FlagValues()
gflags.FLAGS = FLAGS
gflags._GetCallingModule = _GetCallingModule
def DEFINE_boolean(name, default, help, flag_values=FLAGS):
DEFINE_bool(name, default, help, flag_values)
DEFINE = _wrapper(gflags.DEFINE)
DEFINE_string = _wrapper(gflags.DEFINE_string)
DEFINE_integer = _wrapper(gflags.DEFINE_integer)
DEFINE_bool = _wrapper(gflags.DEFINE_bool)
DEFINE_boolean = _wrapper(gflags.DEFINE_boolean)
DEFINE_float = _wrapper(gflags.DEFINE_float)
DEFINE_enum = _wrapper(gflags.DEFINE_enum)
DEFINE_list = _wrapper(gflags.DEFINE_list)
DEFINE_spaceseplist = _wrapper(gflags.DEFINE_spaceseplist)
DEFINE_multistring = _wrapper(gflags.DEFINE_multistring)
DEFINE_multi_int = _wrapper(gflags.DEFINE_multi_int)
DEFINE_flag = _wrapper(gflags.DEFINE_flag)
HelpFlag = gflags.HelpFlag
HelpshortFlag = gflags.HelpshortFlag
HelpXMLFlag = gflags.HelpXMLFlag
def DEFINE_list(name, default, help, flag_values=FLAGS):
flag_values.define_list(name, default, help)
def DEFINE_float(name, default, help, flag_values=FLAGS):
flag_values.define_float(name, default, help)
def DEFINE_multistring(name, default, help, flag_values=FLAGS):
flag_values.define_multistring(name, default, help)
class UnrecognizedFlag(Exception):
pass
def DECLARE(name, module_string, flag_values=FLAGS):
if module_string not in sys.modules:
__import__(module_string, globals(), locals())
if name not in flag_values:
raise gflags.UnrecognizedFlag(
"%s not defined by %s" % (name, module_string))
raise UnrecognizedFlag('%s not defined by %s' % (name, module_string))
def DEFINE_flag(flag):
pass
class HelpFlag:
pass
class HelpshortFlag:
pass
class HelpXMLFlag:
pass
def _get_my_ip():

View File

@@ -3,6 +3,7 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
# Copyright 2011 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@@ -16,6 +17,10 @@
# License for the specific language governing permissions and limitations
# under the License.
import exceptions
import os
import tempfile
from nova import exception
from nova import flags
from nova import test
@@ -64,6 +69,37 @@ class FlagsTestCase(test.TestCase):
self.assertEqual(self.FLAGS.false, True)
self.assertEqual(self.FLAGS.true, False)
def test_define_float(self):
flags.DEFINE_float('float', 6.66, 'desc', flag_values=self.FLAGS)
self.assertEqual(self.FLAGS.float, 6.66)
def test_define_multistring(self):
flags.DEFINE_multistring('multi', [], 'desc', flag_values=self.FLAGS)
argv = ['flags_test', '--multi', 'foo', '--multi', 'bar']
self.FLAGS(argv)
self.assertEqual(self.FLAGS.multi, ['foo', 'bar'])
def test_define_list(self):
flags.DEFINE_list('list', ['foo'], 'desc', flag_values=self.FLAGS)
self.assert_(self.FLAGS['list'])
self.assertEqual(self.FLAGS.list, ['foo'])
argv = ['flags_test', '--list=a,b,c,d']
self.FLAGS(argv)
self.assertEqual(self.FLAGS.list, ['a', 'b', 'c', 'd'])
def test_error(self):
flags.DEFINE_integer('error', 1, 'desc', flag_values=self.FLAGS)
self.assertEqual(self.FLAGS.error, 1)
argv = ['flags_test', '--error=foo']
self.assertRaises(exceptions.SystemExit, self.FLAGS, argv)
def test_declare(self):
self.assert_('answer' not in self.global_FLAGS)
flags.DECLARE('answer', 'nova.tests.declare_flags')
@@ -76,6 +112,14 @@ class FlagsTestCase(test.TestCase):
flags.DECLARE('answer', 'nova.tests.declare_flags')
self.assertEqual(self.global_FLAGS.answer, 256)
def test_getopt_non_interspersed_args(self):
self.assert_('runtime_answer' not in self.global_FLAGS)
argv = ['flags_test', 'extra_arg', '--runtime_answer=60']
args = self.global_FLAGS(argv)
self.assertEqual(len(args), 3)
self.assertEqual(argv, args)
def test_runtime_and_unknown_flags(self):
self.assert_('runtime_answer' not in self.global_FLAGS)
@@ -114,3 +158,49 @@ class FlagsTestCase(test.TestCase):
self.assertEqual(FLAGS.flags_unittest, 'foo')
FLAGS.flags_unittest = 'bar'
self.assertEqual(FLAGS.flags_unittest, 'bar')
def test_flag_overrides(self):
self.assertEqual(FLAGS.flags_unittest, 'foo')
self.flags(flags_unittest='bar')
self.assertEqual(FLAGS.flags_unittest, 'bar')
self.assertEqual(FLAGS['flags_unittest'].value, 'bar')
self.assertEqual(FLAGS.FlagValuesDict()['flags_unittest'], 'bar')
self.reset_flags()
self.assertEqual(FLAGS.flags_unittest, 'foo')
self.assertEqual(FLAGS['flags_unittest'].value, 'foo')
self.assertEqual(FLAGS.FlagValuesDict()['flags_unittest'], 'foo')
def test_flagfile(self):
flags.DEFINE_string('string', 'default', 'desc',
flag_values=self.FLAGS)
flags.DEFINE_integer('int', 1, 'desc', flag_values=self.FLAGS)
flags.DEFINE_bool('false', False, 'desc', flag_values=self.FLAGS)
flags.DEFINE_bool('true', True, 'desc', flag_values=self.FLAGS)
(fd, path) = tempfile.mkstemp(prefix='nova', suffix='.flags')
try:
os.write(fd, '--string=foo\n--int=2\n--false\n--notrue\n')
os.close(fd)
self.FLAGS(['flags_test', '--flagfile=' + path])
self.assertEqual(self.FLAGS.string, 'foo')
self.assertEqual(self.FLAGS.int, 2)
self.assertEqual(self.FLAGS.false, True)
self.assertEqual(self.FLAGS.true, False)
finally:
os.remove(path)
def test_defaults(self):
flags.DEFINE_string('foo', 'bar', 'help', flag_values=self.FLAGS)
self.assertEqual(self.FLAGS.foo, 'bar')
self.FLAGS['foo'].SetDefault('blaa')
self.assertEqual(self.FLAGS.foo, 'blaa')
def test_templated_values(self):
flags.DEFINE_string('foo', 'foo', 'help', flag_values=self.FLAGS)
flags.DEFINE_string('bar', 'bar', 'help', flag_values=self.FLAGS)
flags.DEFINE_string('blaa', '$foo$bar', 'help', flag_values=self.FLAGS)
self.assertEqual(self.FLAGS.blaa, 'foobar')