Merge "Start switching from gflags to optparse"

This commit is contained in:
Jenkins
2011-10-19 20:17:02 +00:00
committed by Gerrit Code Review
2 changed files with 281 additions and 170 deletions

View File

@@ -3,6 +3,7 @@
# Copyright 2010 United States Government as represented by the # Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration. # Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved. # All Rights Reserved.
# Copyright 2011 Red Hat, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@@ -18,14 +19,14 @@
"""Command-line flag library. """Command-line flag library.
Wraps gflags. Emulates gflags by wrapping optparse.
Package-level global flags are defined here, the rest are defined The idea is to move to optparse eventually, and this wrapper is a
where they're used. stepping stone.
""" """
import getopt import optparse
import os import os
import socket import socket
import string import string
@@ -34,120 +35,165 @@ import sys
import gflags import gflags
class FlagValues(gflags.FlagValues): class FlagValues(object):
"""Extension of gflags.FlagValues that allows undefined and runtime flags. class Flag:
def __init__(self, name, value, update_default=None):
self.name = name
self.value = value
self._update_default = update_default
Unknown flags will be ignored when parsing the command line, but the def SetDefault(self, default):
command line will be kept so that it can be replayed if new flags are if self._update_default:
defined after the initial parsing. self._update_default(self.name, default)
""" class ErrorCatcher:
def __init__(self, orig_error):
self.orig_error = orig_error
self.reset()
def reset(self):
self._error_msg = None
def catch(self, msg):
if ": --" in msg:
self._error_msg = msg
else:
self.orig_error(msg)
def get_unknown_arg(self, args):
if not self._error_msg:
return None
# Error message is e.g. "no such option: --runtime_answer"
a = self._error_msg[self._error_msg.rindex(": --") + 2:]
return filter(lambda i: i == a or i.startswith(a + "="), args)[0]
def __init__(self, extra_context=None): def __init__(self, extra_context=None):
gflags.FlagValues.__init__(self) self._parser = optparse.OptionParser()
self.__dict__['__dirty'] = [] self._parser.disable_interspersed_args()
self.__dict__['__was_already_parsed'] = False self._extra_context = extra_context
self.__dict__['__stored_argv'] = [] self.Reset()
self.__dict__['__extra_context'] = extra_context
def _parse(self):
if not self._values is None:
return
args = gflags.FlagValues().ReadFlagsFromFiles(self._args)
values = extra = None
#
# This horrendous hack allows us to stop optparse
# exiting when it encounters an unknown option
#
error_catcher = self.ErrorCatcher(self._parser.error)
self._parser.error = error_catcher.catch
try:
while True:
error_catcher.reset()
(values, extra) = self._parser.parse_args(args)
unknown = error_catcher.get_unknown_arg(args)
if not unknown:
break
args.remove(unknown)
finally:
self._parser.error = error_catcher.orig_error
(self._values, self._extra) = (values, extra)
def __call__(self, argv): def __call__(self, argv):
# We're doing some hacky stuff here so that we don't have to copy self._args = argv[1:]
# out all the code of the original verbatim and then tweak a few lines. self._values = None
# We're hijacking the output of getopt so we can still return the self._parse()
# leftover args at the end return [argv[0]] + self._extra
sneaky_unparsed_args = {"value": None}
original_argv = list(argv)
if self.IsGnuGetOpt():
orig_getopt = getattr(getopt, 'gnu_getopt')
orig_name = 'gnu_getopt'
else:
orig_getopt = getattr(getopt, 'getopt')
orig_name = 'getopt'
def _sneaky(*args, **kw):
optlist, unparsed_args = orig_getopt(*args, **kw)
sneaky_unparsed_args['value'] = unparsed_args
return optlist, unparsed_args
try:
setattr(getopt, orig_name, _sneaky)
args = gflags.FlagValues.__call__(self, argv)
except gflags.UnrecognizedFlagError:
# Undefined args were found, for now we don't care so just
# act like everything went well
# (these three lines are copied pretty much verbatim from the end
# of the __call__ function we are wrapping)
unparsed_args = sneaky_unparsed_args['value']
if unparsed_args:
if self.IsGnuGetOpt():
args = argv[:1] + unparsed_args
else:
args = argv[:1] + original_argv[-len(unparsed_args):]
else:
args = argv[:1]
finally:
setattr(getopt, orig_name, orig_getopt)
# Store the arguments for later, we'll need them for new flags
# added at runtime
self.__dict__['__stored_argv'] = original_argv
self.__dict__['__was_already_parsed'] = True
self.ClearDirty()
return args
def Reset(self):
gflags.FlagValues.Reset(self)
self.__dict__['__dirty'] = []
self.__dict__['__was_already_parsed'] = False
self.__dict__['__stored_argv'] = []
def SetDirty(self, name):
"""Mark a flag as dirty so that accessing it will case a reparse."""
self.__dict__['__dirty'].append(name)
def IsDirty(self, name):
return name in self.__dict__['__dirty']
def ClearDirty(self):
self.__dict__['__dirty'] = []
def WasAlreadyParsed(self):
return self.__dict__['__was_already_parsed']
def ParseNewFlags(self):
if '__stored_argv' not in self.__dict__:
return
new_flags = FlagValues(self)
for k in self.FlagDict().iterkeys():
new_flags[k] = gflags.FlagValues.__getitem__(self, k)
new_flags.Reset()
new_flags(self.__dict__['__stored_argv'])
for k in new_flags.FlagDict().iterkeys():
setattr(self, k, getattr(new_flags, k))
self.ClearDirty()
def __setitem__(self, name, flag):
gflags.FlagValues.__setitem__(self, name, flag)
if self.WasAlreadyParsed():
self.SetDirty(name)
def __getitem__(self, name):
if self.IsDirty(name):
self.ParseNewFlags()
return gflags.FlagValues.__getitem__(self, name)
def __getattr__(self, name): def __getattr__(self, name):
if self.IsDirty(name): self._parse()
self.ParseNewFlags() val = getattr(self._values, name)
val = gflags.FlagValues.__getattr__(self, name)
if type(val) is str: if type(val) is str:
tmpl = string.Template(val) tmpl = string.Template(val)
context = [self, self.__dict__['__extra_context']] context = [self, self._extra_context]
return tmpl.substitute(StrWrapper(context)) return tmpl.substitute(StrWrapper(context))
return val return val
def get(self, name, default):
value = getattr(self, name)
if value is not None: # value might be '0' or ""
return value
else:
return default
def __contains__(self, name):
self._parse()
return hasattr(self._values, name)
def _update_default(self, name, default):
self._parser.set_default(name, default)
self._values = None
def __iter__(self):
return self.FlagValuesDict().iterkeys()
def __getitem__(self, name):
self._parse()
if not self.__contains__(name):
return None
return self.Flag(name, getattr(self, name), self._update_default)
def Reset(self):
self._args = []
self._values = None
self._extra = None
def ParseNewFlags(self):
pass
def FlagValuesDict(self):
ret = {}
for opt in self._parser.option_list:
if opt.dest:
ret[opt.dest] = getattr(self, opt.dest)
return ret
def _add_option(self, name, default, help, prefix='--', **kwargs):
prefixed_name = prefix + name
for opt in self._parser.option_list:
if prefixed_name == opt.get_opt_string():
return
self._parser.add_option(prefixed_name, dest=name,
default=default, help=help, **kwargs)
self._values = None
def define_string(self, name, default, help):
self._add_option(name, default, help)
def define_integer(self, name, default, help):
self._add_option(name, default, help, type='int')
def define_float(self, name, default, help):
self._add_option(name, default, help, type='float')
def define_bool(self, name, default, help):
#
# FIXME(markmc): this doesn't support --boolflag=true/false/t/f/1/0
#
self._add_option(name, default, help, action='store_true')
self._add_option(name, default, help,
prefix="--no", action='store_false')
def define_list(self, name, default, help):
def parse_list(option, opt, value, parser):
setattr(self._parser.values, name, value.split(','))
self._add_option(name, default, help, type='string',
action='callback', callback=parse_list)
def define_multistring(self, name, default, help):
self._add_option(name, default, help, action='append')
FLAGS = FlagValues()
class StrWrapper(object): class StrWrapper(object):
"""Wrapper around FlagValues objects. """Wrapper around FlagValues objects.
@@ -167,85 +213,60 @@ class StrWrapper(object):
raise KeyError(name) raise KeyError(name)
# Copied from gflags with small mods to get the naming correct. def DEFINE_string(name, default, help, flag_values=FLAGS):
# Originally gflags checks for the first module that is not gflags that is flag_values.define_string(name, default, help)
# in the call chain, we want to check for the first module that is not gflags
# and not this module.
def _GetCallingModule():
"""Returns the name of the module that's calling into this module.
We generally use this function to get the name of the module calling a
DEFINE_foo... function.
"""
# Walk down the stack to find the first globals dict that's not ours.
for depth in range(1, sys.getrecursionlimit()):
if not sys._getframe(depth).f_globals is globals():
module_name = __GetModuleName(sys._getframe(depth).f_globals)
if module_name == 'gflags':
continue
if module_name is not None:
return module_name
raise AssertionError("No module was found")
# Copied from gflags because it is a private function def DEFINE_integer(name, default, help, lower_bound=None, flag_values=FLAGS):
def __GetModuleName(globals_dict): # FIXME(markmc): ignoring lower_bound
"""Given a globals dict, returns the name of the module that defines it. flag_values.define_integer(name, default, help)
Args:
globals_dict: A dictionary that should correspond to an environment
providing the values of the globals.
Returns:
A string (the name of the module) or None (if the module could not
be identified.
"""
for name, module in sys.modules.iteritems():
if getattr(module, '__dict__', None) is globals_dict:
if name == '__main__':
return sys.argv[0]
return name
return None
def _wrapper(func): def DEFINE_bool(name, default, help, flag_values=FLAGS):
def _wrapped(*args, **kw): flag_values.define_bool(name, default, help)
kw.setdefault('flag_values', FLAGS)
func(*args, **kw)
_wrapped.func_name = func.func_name
return _wrapped
FLAGS = FlagValues() def DEFINE_boolean(name, default, help, flag_values=FLAGS):
gflags.FLAGS = FLAGS DEFINE_bool(name, default, help, flag_values)
gflags._GetCallingModule = _GetCallingModule
DEFINE = _wrapper(gflags.DEFINE) def DEFINE_list(name, default, help, flag_values=FLAGS):
DEFINE_string = _wrapper(gflags.DEFINE_string) flag_values.define_list(name, default, help)
DEFINE_integer = _wrapper(gflags.DEFINE_integer)
DEFINE_bool = _wrapper(gflags.DEFINE_bool)
DEFINE_boolean = _wrapper(gflags.DEFINE_boolean) def DEFINE_float(name, default, help, flag_values=FLAGS):
DEFINE_float = _wrapper(gflags.DEFINE_float) flag_values.define_float(name, default, help)
DEFINE_enum = _wrapper(gflags.DEFINE_enum)
DEFINE_list = _wrapper(gflags.DEFINE_list)
DEFINE_spaceseplist = _wrapper(gflags.DEFINE_spaceseplist) def DEFINE_multistring(name, default, help, flag_values=FLAGS):
DEFINE_multistring = _wrapper(gflags.DEFINE_multistring) flag_values.define_multistring(name, default, help)
DEFINE_multi_int = _wrapper(gflags.DEFINE_multi_int)
DEFINE_flag = _wrapper(gflags.DEFINE_flag)
HelpFlag = gflags.HelpFlag class UnrecognizedFlag(Exception):
HelpshortFlag = gflags.HelpshortFlag pass
HelpXMLFlag = gflags.HelpXMLFlag
def DECLARE(name, module_string, flag_values=FLAGS): def DECLARE(name, module_string, flag_values=FLAGS):
if module_string not in sys.modules: if module_string not in sys.modules:
__import__(module_string, globals(), locals()) __import__(module_string, globals(), locals())
if name not in flag_values: if name not in flag_values:
raise gflags.UnrecognizedFlag( raise UnrecognizedFlag('%s not defined by %s' % (name, module_string))
"%s not defined by %s" % (name, module_string))
def DEFINE_flag(flag):
pass
class HelpFlag:
pass
class HelpshortFlag:
pass
class HelpXMLFlag:
pass
def _get_my_ip(): def _get_my_ip():

View File

@@ -3,6 +3,7 @@
# Copyright 2010 United States Government as represented by the # Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration. # Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved. # All Rights Reserved.
# Copyright 2011 Red Hat, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@@ -16,6 +17,10 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import exceptions
import os
import tempfile
from nova import exception from nova import exception
from nova import flags from nova import flags
from nova import test from nova import test
@@ -64,6 +69,37 @@ class FlagsTestCase(test.TestCase):
self.assertEqual(self.FLAGS.false, True) self.assertEqual(self.FLAGS.false, True)
self.assertEqual(self.FLAGS.true, False) self.assertEqual(self.FLAGS.true, False)
def test_define_float(self):
flags.DEFINE_float('float', 6.66, 'desc', flag_values=self.FLAGS)
self.assertEqual(self.FLAGS.float, 6.66)
def test_define_multistring(self):
flags.DEFINE_multistring('multi', [], 'desc', flag_values=self.FLAGS)
argv = ['flags_test', '--multi', 'foo', '--multi', 'bar']
self.FLAGS(argv)
self.assertEqual(self.FLAGS.multi, ['foo', 'bar'])
def test_define_list(self):
flags.DEFINE_list('list', ['foo'], 'desc', flag_values=self.FLAGS)
self.assert_(self.FLAGS['list'])
self.assertEqual(self.FLAGS.list, ['foo'])
argv = ['flags_test', '--list=a,b,c,d']
self.FLAGS(argv)
self.assertEqual(self.FLAGS.list, ['a', 'b', 'c', 'd'])
def test_error(self):
flags.DEFINE_integer('error', 1, 'desc', flag_values=self.FLAGS)
self.assertEqual(self.FLAGS.error, 1)
argv = ['flags_test', '--error=foo']
self.assertRaises(exceptions.SystemExit, self.FLAGS, argv)
def test_declare(self): def test_declare(self):
self.assert_('answer' not in self.global_FLAGS) self.assert_('answer' not in self.global_FLAGS)
flags.DECLARE('answer', 'nova.tests.declare_flags') flags.DECLARE('answer', 'nova.tests.declare_flags')
@@ -76,6 +112,14 @@ class FlagsTestCase(test.TestCase):
flags.DECLARE('answer', 'nova.tests.declare_flags') flags.DECLARE('answer', 'nova.tests.declare_flags')
self.assertEqual(self.global_FLAGS.answer, 256) self.assertEqual(self.global_FLAGS.answer, 256)
def test_getopt_non_interspersed_args(self):
self.assert_('runtime_answer' not in self.global_FLAGS)
argv = ['flags_test', 'extra_arg', '--runtime_answer=60']
args = self.global_FLAGS(argv)
self.assertEqual(len(args), 3)
self.assertEqual(argv, args)
def test_runtime_and_unknown_flags(self): def test_runtime_and_unknown_flags(self):
self.assert_('runtime_answer' not in self.global_FLAGS) self.assert_('runtime_answer' not in self.global_FLAGS)
@@ -114,3 +158,49 @@ class FlagsTestCase(test.TestCase):
self.assertEqual(FLAGS.flags_unittest, 'foo') self.assertEqual(FLAGS.flags_unittest, 'foo')
FLAGS.flags_unittest = 'bar' FLAGS.flags_unittest = 'bar'
self.assertEqual(FLAGS.flags_unittest, 'bar') self.assertEqual(FLAGS.flags_unittest, 'bar')
def test_flag_overrides(self):
self.assertEqual(FLAGS.flags_unittest, 'foo')
self.flags(flags_unittest='bar')
self.assertEqual(FLAGS.flags_unittest, 'bar')
self.assertEqual(FLAGS['flags_unittest'].value, 'bar')
self.assertEqual(FLAGS.FlagValuesDict()['flags_unittest'], 'bar')
self.reset_flags()
self.assertEqual(FLAGS.flags_unittest, 'foo')
self.assertEqual(FLAGS['flags_unittest'].value, 'foo')
self.assertEqual(FLAGS.FlagValuesDict()['flags_unittest'], 'foo')
def test_flagfile(self):
flags.DEFINE_string('string', 'default', 'desc',
flag_values=self.FLAGS)
flags.DEFINE_integer('int', 1, 'desc', flag_values=self.FLAGS)
flags.DEFINE_bool('false', False, 'desc', flag_values=self.FLAGS)
flags.DEFINE_bool('true', True, 'desc', flag_values=self.FLAGS)
(fd, path) = tempfile.mkstemp(prefix='nova', suffix='.flags')
try:
os.write(fd, '--string=foo\n--int=2\n--false\n--notrue\n')
os.close(fd)
self.FLAGS(['flags_test', '--flagfile=' + path])
self.assertEqual(self.FLAGS.string, 'foo')
self.assertEqual(self.FLAGS.int, 2)
self.assertEqual(self.FLAGS.false, True)
self.assertEqual(self.FLAGS.true, False)
finally:
os.remove(path)
def test_defaults(self):
flags.DEFINE_string('foo', 'bar', 'help', flag_values=self.FLAGS)
self.assertEqual(self.FLAGS.foo, 'bar')
self.FLAGS['foo'].SetDefault('blaa')
self.assertEqual(self.FLAGS.foo, 'blaa')
def test_templated_values(self):
flags.DEFINE_string('foo', 'foo', 'help', flag_values=self.FLAGS)
flags.DEFINE_string('bar', 'bar', 'help', flag_values=self.FLAGS)
flags.DEFINE_string('blaa', '$foo$bar', 'help', flag_values=self.FLAGS)
self.assertEqual(self.FLAGS.blaa, 'foobar')