Add some useful features to our flags
* No longer dies if there are unknown flags. * Allows you to declare that you will use a flag from another file * Allows you to import new flags at runtime and reparses the original arguments to fill them once they are accessed.
This commit is contained in:
133
nova/flags.py
133
nova/flags.py
@@ -21,16 +21,137 @@ Package-level global flags are defined here, the rest are defined
|
||||
where they're used.
|
||||
"""
|
||||
|
||||
import getopt
|
||||
import socket
|
||||
import sys
|
||||
|
||||
import gflags
|
||||
|
||||
|
||||
from gflags import *
|
||||
class FlagValues(gflags.FlagValues):
|
||||
def __init__(self):
|
||||
gflags.FlagValues.__init__(self)
|
||||
self.__dict__['__dirty'] = []
|
||||
self.__dict__['__was_already_parsed'] = False
|
||||
self.__dict__['__stored_argv'] = []
|
||||
|
||||
def __call__(self, argv):
|
||||
# We're doing some hacky stuff here so that we don't have to copy
|
||||
# out all the code of the original verbatim and then tweak a few lines.
|
||||
# We're hijacking the output of getopt so we can still return the
|
||||
# leftover args at the end
|
||||
sneaky_unparsed_args = {"value": None}
|
||||
original_argv = list(argv)
|
||||
|
||||
if self.IsGnuGetOpt():
|
||||
orig_getopt = getattr(getopt, 'gnu_getopt')
|
||||
orig_name = 'gnu_getopt'
|
||||
else:
|
||||
orig_getopt = getattr(getopt, 'getopt')
|
||||
orig_name = 'getopt'
|
||||
|
||||
def _sneaky(*args, **kw):
|
||||
optlist, unparsed_args = orig_getopt(*args, **kw)
|
||||
sneaky_unparsed_args['value'] = unparsed_args
|
||||
return optlist, unparsed_args
|
||||
|
||||
try:
|
||||
setattr(getopt, orig_name, _sneaky)
|
||||
args = gflags.FlagValues.__call__(self, argv)
|
||||
except gflags.UnrecognizedFlagError:
|
||||
# Undefined args were found, for now we don't care so just
|
||||
# act like everything went well
|
||||
# (these three lines are copied pretty much verbatim from the end
|
||||
# of the __call__ function we are wrapping)
|
||||
unparsed_args = sneaky_unparsed_args['value']
|
||||
if unparsed_args:
|
||||
if self.IsGnuGetOpt():
|
||||
args = argv[:1] + unparsed
|
||||
else:
|
||||
args = argv[:1] + original_argv[-len(unparsed_args):]
|
||||
else:
|
||||
args = argv[:1]
|
||||
finally:
|
||||
setattr(getopt, orig_name, orig_getopt)
|
||||
|
||||
# Store the arguments for later, we'll need them for new flags
|
||||
# added at runtime
|
||||
self.__dict__['__stored_argv'] = original_argv
|
||||
self.__dict__['__was_already_parsed'] = True
|
||||
self.ClearDirty()
|
||||
return args
|
||||
|
||||
def SetDirty(self, name):
|
||||
"""Mark a flag as dirty so that accessing it will case a reparse."""
|
||||
self.__dict__['__dirty'].append(name)
|
||||
|
||||
def IsDirty(self, name):
|
||||
return name in self.__dict__['__dirty']
|
||||
|
||||
def ClearDirty(self):
|
||||
self.__dict__['__is_dirty'] = []
|
||||
|
||||
def WasAlreadyParsed(self):
|
||||
return self.__dict__['__was_already_parsed']
|
||||
|
||||
def ParseNewFlags(self):
|
||||
if '__stored_argv' not in self.__dict__:
|
||||
return
|
||||
new_flags = FlagValues()
|
||||
for k in self.__dict__['__dirty']:
|
||||
new_flags[k] = gflags.FlagValues.__getitem__(self, k)
|
||||
|
||||
new_flags(self.__dict__['__stored_argv'])
|
||||
for k in self.__dict__['__dirty']:
|
||||
setattr(self, k, getattr(new_flags, k))
|
||||
self.ClearDirty()
|
||||
|
||||
def __setitem__(self, name, flag):
|
||||
gflags.FlagValues.__setitem__(self, name, flag)
|
||||
if self.WasAlreadyParsed():
|
||||
self.SetDirty(name)
|
||||
|
||||
def __getitem__(self, name):
|
||||
if self.IsDirty(name):
|
||||
self.ParseNewFlags()
|
||||
return gflags.FlagValues.__getitem__(self, name)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self.IsDirty(name):
|
||||
self.ParseNewFlags()
|
||||
return gflags.FlagValues.__getattr__(self, name)
|
||||
|
||||
|
||||
FLAGS = FlagValues()
|
||||
|
||||
|
||||
def party_wrapper(func):
|
||||
def _wrapped(*args, **kw):
|
||||
kw.setdefault('flag_values', FLAGS)
|
||||
func(*args, **kw)
|
||||
_wrapped.func_name = func.func_name
|
||||
return _wrapped
|
||||
|
||||
|
||||
DEFINE_string = party_wrapper(gflags.DEFINE_string)
|
||||
DEFINE_integer = party_wrapper(gflags.DEFINE_integer)
|
||||
DEFINE_bool = party_wrapper(gflags.DEFINE_bool)
|
||||
DEFINE_boolean = party_wrapper(gflags.DEFINE_boolean)
|
||||
DEFINE_float = party_wrapper(gflags.DEFINE_float)
|
||||
DEFINE_enum = party_wrapper(gflags.DEFINE_enum)
|
||||
DEFINE_list = party_wrapper(gflags.DEFINE_list)
|
||||
DEFINE_spaceseplist = party_wrapper(gflags.DEFINE_spaceseplist)
|
||||
DEFINE_multistring = party_wrapper(gflags.DEFINE_multistring)
|
||||
DEFINE_multi_int = party_wrapper(gflags.DEFINE_multi_int)
|
||||
|
||||
|
||||
def DECLARE(name, module_string, flag_values=FLAGS):
|
||||
if module_string not in sys.modules:
|
||||
__import__(module_string, globals(), locals())
|
||||
if name not in flag_values:
|
||||
raise gflags.UnrecognizedFlag(
|
||||
"%s not defined by %s" % (name, module_string))
|
||||
|
||||
# This keeps pylint from barfing on the imports
|
||||
FLAGS = FLAGS
|
||||
DEFINE_string = DEFINE_string
|
||||
DEFINE_integer = DEFINE_integer
|
||||
DEFINE_bool = DEFINE_bool
|
||||
|
||||
# __GLOBAL FLAGS ONLY__
|
||||
# Define any app-specific flags in their own files, docs at:
|
||||
|
5
nova/tests/declare_flags.py
Normal file
5
nova/tests/declare_flags.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from nova import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_integer('answer', 42, 'test flag')
|
94
nova/tests/flags_unittest.py
Normal file
94
nova/tests/flags_unittest.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright 2010 United States Government as represented by the
|
||||
# Administrator of the National Aeronautics and Space Administration.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import reactor
|
||||
from xml.etree import ElementTree
|
||||
|
||||
from nova import exception
|
||||
from nova import flags
|
||||
from nova import process
|
||||
from nova import test
|
||||
from nova import utils
|
||||
|
||||
|
||||
class FlagsTestCase(test.TrialTestCase):
|
||||
def setUp(self):
|
||||
super(FlagsTestCase, self).setUp()
|
||||
self.FLAGS = flags.FlagValues()
|
||||
self.global_FLAGS = flags.FLAGS
|
||||
|
||||
def test_define(self):
|
||||
self.assert_('string' not in self.FLAGS)
|
||||
self.assert_('int' not in self.FLAGS)
|
||||
self.assert_('false' not in self.FLAGS)
|
||||
self.assert_('true' not in self.FLAGS)
|
||||
|
||||
flags.DEFINE_string('string', 'default', 'desc', flag_values=self.FLAGS)
|
||||
flags.DEFINE_integer('int', 1, 'desc', flag_values=self.FLAGS)
|
||||
flags.DEFINE_bool('false', False, 'desc', flag_values=self.FLAGS)
|
||||
flags.DEFINE_bool('true', True, 'desc', flag_values=self.FLAGS)
|
||||
|
||||
self.assert_(self.FLAGS['string'])
|
||||
self.assert_(self.FLAGS['int'])
|
||||
self.assert_(self.FLAGS['false'])
|
||||
self.assert_(self.FLAGS['true'])
|
||||
self.assertEqual(self.FLAGS.string, 'default')
|
||||
self.assertEqual(self.FLAGS.int, 1)
|
||||
self.assertEqual(self.FLAGS.false, False)
|
||||
self.assertEqual(self.FLAGS.true, True)
|
||||
|
||||
argv = ['flags_test',
|
||||
'--string', 'foo',
|
||||
'--int', '2',
|
||||
'--false',
|
||||
'--notrue']
|
||||
|
||||
self.FLAGS(argv)
|
||||
self.assertEqual(self.FLAGS.string, 'foo')
|
||||
self.assertEqual(self.FLAGS.int, 2)
|
||||
self.assertEqual(self.FLAGS.false, True)
|
||||
self.assertEqual(self.FLAGS.true, False)
|
||||
|
||||
def test_declare(self):
|
||||
self.assert_('answer' not in self.global_FLAGS)
|
||||
flags.DECLARE('answer', 'nova.tests.declare_flags')
|
||||
self.assert_('answer' in self.global_FLAGS)
|
||||
self.assertEqual(self.global_FLAGS.answer, 42)
|
||||
|
||||
# Make sure we don't overwrite anything
|
||||
self.global_FLAGS.answer = 256
|
||||
self.assertEqual(self.global_FLAGS.answer, 256)
|
||||
flags.DECLARE('answer', 'nova.tests.declare_flags')
|
||||
self.assertEqual(self.global_FLAGS.answer, 256)
|
||||
|
||||
def test_runtime_and_unknown_flags(self):
|
||||
self.assert_('runtime_answer' not in self.global_FLAGS)
|
||||
|
||||
argv = ['flags_test', '--runtime_answer=60', 'extra_arg']
|
||||
args = self.global_FLAGS(argv)
|
||||
self.assertEqual(len(args), 2)
|
||||
self.assertEqual(args[1], 'extra_arg')
|
||||
|
||||
self.assert_('runtime_answer' not in self.global_FLAGS)
|
||||
|
||||
import nova.tests.runtime_flags
|
||||
|
||||
self.assert_('runtime_answer' in self.global_FLAGS)
|
||||
self.assertEqual(self.global_FLAGS.runtime_answer, 60)
|
5
nova/tests/runtime_flags.py
Normal file
5
nova/tests/runtime_flags.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from nova import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_integer('runtime_answer', 54, 'test flag')
|
@@ -54,6 +54,7 @@ from nova.tests.auth_unittest import *
|
||||
from nova.tests.api_unittest import *
|
||||
from nova.tests.cloud_unittest import *
|
||||
from nova.tests.compute_unittest import *
|
||||
from nova.tests.flags_unittest import *
|
||||
from nova.tests.model_unittest import *
|
||||
from nova.tests.network_unittest import *
|
||||
from nova.tests.objectstore_unittest import *
|
||||
|
Reference in New Issue
Block a user