Made packstack.installer.setup_params pep8 compliant and added unit test for this module

Change-Id: Iff9d3f52ba3bc845d7bf01816dae1b4160581a3f
This commit is contained in:
Martin Magr 2013-02-28 17:00:42 +01:00
parent efdd25cb16
commit 0fdc119fcd
6 changed files with 358 additions and 192 deletions

View File

@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
import copy
from types import GeneratorType
# taken from Django.utils.datastructures
class SortedDict(dict):
"""
A dictionary that keeps its keys in the order in which they're inserted.
"""
def __new__(cls, *args, **kwargs):
instance = super(SortedDict, cls).__new__(cls, *args, **kwargs)
instance.keyOrder = []
return instance
def __init__(self, data=None):
if data is None:
data = {}
elif isinstance(data, GeneratorType):
# Unfortunately we need to be able to read a generator twice. Once
# to get the data into self with our super().__init__ call and a
# second time to setup keyOrder correctly
data = list(data)
super(SortedDict, self).__init__(data)
if isinstance(data, dict):
self.keyOrder = data.keys()
else:
self.keyOrder = []
seen = set()
for key, value in data:
if key not in seen:
self.keyOrder.append(key)
seen.add(key)
def __deepcopy__(self, memo):
return self.__class__([(key, copy.deepcopy(value, memo))
for key, value in self.iteritems()])
def __setitem__(self, key, value):
if key not in self:
self.keyOrder.append(key)
super(SortedDict, self).__setitem__(key, value)
def __delitem__(self, key):
super(SortedDict, self).__delitem__(key)
self.keyOrder.remove(key)
def __iter__(self):
return iter(self.keyOrder)
def pop(self, k, *args):
result = super(SortedDict, self).pop(k, *args)
try:
self.keyOrder.remove(k)
except ValueError:
# Key wasn't in the dictionary in the first place. No problem.
pass
return result
def popitem(self):
result = super(SortedDict, self).popitem()
self.keyOrder.remove(result[0])
return result
def items(self):
return zip(self.keyOrder, self.values())
def iteritems(self):
for key in self.keyOrder:
yield key, self[key]
def keys(self):
return self.keyOrder[:]
def iterkeys(self):
return iter(self.keyOrder)
def values(self):
return map(self.__getitem__, self.keyOrder)
def itervalues(self):
for key in self.keyOrder:
yield self[key]
def update(self, dict_):
for k, v in dict_.iteritems():
self[k] = v
def setdefault(self, key, default):
if key not in self:
self.keyOrder.append(key)
return super(SortedDict, self).setdefault(key, default)
def value_for_index(self, index):
"""Returns the value of the item at the given zero-based index."""
return self[self.keyOrder[index]]
def insert(self, index, key, value):
"""Inserts the key, value pair before the item with the given index."""
if key in self.keyOrder:
n = self.keyOrder.index(key)
del self.keyOrder[n]
if n < index:
index -= 1
self.keyOrder.insert(index, key)
super(SortedDict, self).__setitem__(key, value)
def copy(self):
"""Returns a copy of this object."""
# This way of initializing the copy means it works for subclasses, too.
obj = self.__class__(self)
obj.keyOrder = self.keyOrder[:]
return obj
def __repr__(self):
"""
Replaces the normal dict.__repr__ with a version that returns the keys
in their sorted order.
"""
return '{%s}' % ', '.join(['%r: %r' % (k, v) for k, v in self.items()])
def clear(self):
super(SortedDict, self).clear()
self.keyOrder = []

View File

@ -70,37 +70,37 @@ def _getInputFromUser(param):
userInput = None
try:
if param.getKey("USE_DEFAULT"):
logging.debug("setting default value (%s) for key (%s)" % (mask(param.getKey("DEFAULT_VALUE")), param.getKey("CONF_NAME")))
controller.CONF[param.getKey("CONF_NAME")] = param.getKey("DEFAULT_VALUE")
if param.USE_DEFAULT:
logging.debug("setting default value (%s) for key (%s)" % (mask(param.DEFAULT_VALUE), param.CONF_NAME))
controller.CONF[param.CONF_NAME] = param.DEFAULT_VALUE
else:
while loop:
# If the value was not supplied by the command line flags
if not commandLineValues.has_key(param.getKey("CONF_NAME")):
if not commandLineValues.has_key(param.CONF_NAME):
message = StringIO()
message.write(param.getKey("PROMPT"))
message.write(param.PROMPT)
val_list = param.getKey("VALIDATORS") or []
val_list = param.VALIDATORS or []
if validators.validate_regexp not in val_list \
and param.getKey("OPTION_LIST"):
message.write(" [%s]" % "|".join(param.getKey("OPTION_LIST")))
and param.OPTION_LIST:
message.write(" [%s]" % "|".join(param.OPTION_LIST))
if param.getKey("DEFAULT_VALUE"):
message.write(" [%s] " % (str(param.getKey("DEFAULT_VALUE"))))
if param.DEFAULT_VALUE:
message.write(" [%s] " % (str(param.DEFAULT_VALUE)))
message.write(": ")
message.seek(0)
#mask password or hidden fields
if (param.getKey("MASK_INPUT")):
userInput = getpass.getpass("%s :" % (param.getKey("PROMPT")))
if (param.MASK_INPUT):
userInput = getpass.getpass("%s :" % (param.PROMPT))
else:
userInput = raw_input(message.read())
else:
userInput = commandLineValues[param.getKey("CONF_NAME")]
userInput = commandLineValues[param.CONF_NAME]
# If DEFAULT_VALUE is set and user did not input anything
if userInput == "" and len(str(param.getKey("DEFAULT_VALUE"))) > 0:
userInput = param.getKey("DEFAULT_VALUE")
if userInput == "" and len(str(param.DEFAULT_VALUE)) > 0:
userInput = param.DEFAULT_VALUE
# Param processing
userInput = process_param_value(param, userInput)
@ -108,31 +108,31 @@ def _getInputFromUser(param):
# If param requires validation
try:
validate_param_value(param, userInput)
controller.CONF[param.getKey("CONF_NAME")] = userInput
controller.CONF[param.CONF_NAME] = userInput
loop = False
except ParamValidationError:
if param.getKey("LOOSE_VALIDATION"):
if param.LOOSE_VALIDATION:
# If validation failed but LOOSE_VALIDATION is true, ask user
answer = _askYesNo("User input failed validation, "
"do you still wish to use it")
loop = not answer
if answer:
controller.CONF[param.getKey("CONF_NAME")] = userInput
controller.CONF[param.CONF_NAME] = userInput
continue
else:
if commandLineValues.has_key(param.getKey("CONF_NAME")):
del commandLineValues[param.getKey("CONF_NAME")]
if commandLineValues.has_key(param.CONF_NAME):
del commandLineValues[param.CONF_NAME]
else:
# Delete value from commandLineValues so that we will prompt the user for input
if commandLineValues.has_key(param.getKey("CONF_NAME")):
del commandLineValues[param.getKey("CONF_NAME")]
if commandLineValues.has_key(param.CONF_NAME):
del commandLineValues[param.CONF_NAME]
loop = True
except KeyboardInterrupt:
print "" # add the new line so messages wont be displayed in the same line as the question
raise
except:
logging.error(traceback.format_exc())
raise Exception(output_messages.ERR_EXP_READ_INPUT_PARAM % (param.getKey("CONF_NAME")))
raise Exception(output_messages.ERR_EXP_READ_INPUT_PARAM % (param.CONF_NAME))
def input_param(param):
"""
@ -141,18 +141,18 @@ def input_param(param):
"""
# We need to check if a param needs confirmation, (i.e. ask user twice)
# Do not validate if it was given from the command line
if (param.getKey("NEED_CONFIRM") and not commandLineValues.has_key(param.getKey("CONF_NAME"))):
if (param.NEED_CONFIRM and not commandLineValues.has_key(param.CONF_NAME)):
#create a copy of the param so we can call it twice
confirmedParam = copy.deepcopy(param)
confirmedParamName = param.getKey("CONF_NAME") + "_CONFIRMED"
confirmedParam.setKey("CONF_NAME", confirmedParamName)
confirmedParam.setKey("PROMPT", output_messages.INFO_CONF_PARAMS_PASSWD_CONFIRM_PROMPT)
confirmedParam.setKey("VALIDATORS", [validators.validate_not_empty])
confirmedParamName = param.CONF_NAME + "_CONFIRMED"
confirmedParam.CONF_NAME = confirmedParamName
confirmedParam.PROMPT = output_messages.INFO_CONF_PARAMS_PASSWD_CONFIRM_PROMPT
confirmedParam.VALIDATORS = [validators.validate_not_empty]
# Now get both values from user (with existing validations
while True:
_getInputFromUser(param)
_getInputFromUser(confirmedParam)
if controller.CONF[param.getKey("CONF_NAME")] == controller.CONF[confirmedParamName]:
if controller.CONF[param.CONF_NAME] == controller.CONF[confirmedParamName]:
logging.debug("Param confirmation passed, value for both questions is identical")
break
else:
@ -192,10 +192,10 @@ def _addDefaultsToMaskedValueSet():
"""
global masked_value_set
for group in controller.getAllGroups():
for param in group.getAllParams():
for param in group.parameters.itervalues():
# Keep default password values masked, but ignore default empty values
if ((param.getKey("MASK_INPUT") == True) and param.getKey("DEFAULT_VALUE") != ""):
masked_value_set.add(param.getKey("DEFAULT_VALUE"))
if ((param.MASK_INPUT == True) and param.DEFAULT_VALUE != ""):
masked_value_set.add(param.DEFAULT_VALUE)
def _updateMaskedValueSet():
"""
@ -261,11 +261,11 @@ def maskString(str):
return str
def validate_param_value(param, value):
cname = param.getKey("CONF_NAME")
cname = param.CONF_NAME
logging.debug("Validating parameter %s." % cname)
val_list = param.getKey("VALIDATORS") or []
opt_list = param.getKey("OPTION_LIST")
val_list = param.VALIDATORS or []
opt_list = param.OPTION_LIST
for val_func in val_list:
try:
val_func(value, opt_list)
@ -275,10 +275,10 @@ def validate_param_value(param, value):
def process_param_value(param, value):
_value = value
processors = param.getKey("PROCESSORS") or []
processors = param.PROCESSORS or []
for proc_func in processors:
logging.debug("Processing value of parameter "
"%s." % param.getKey("CONF_NAME"))
"%s." % param.CONF_NAME)
try:
new_value = proc_func(_value, controller.CONF)
if new_value != _value:
@ -290,7 +290,7 @@ def process_param_value(param, value):
"value: %s" % _value)
except processors.ParamProcessingError, ex:
print ("Value processing of parameter %s "
"failed.\n%s" % (param.getKey("CONF_NAME"), ex))
"failed.\n%s" % (param.CONF_NAME, ex))
raise
return _value
@ -336,7 +336,7 @@ def _loadParamFromFile(config, section, paramName):
validate_param_value(param, value)
# Keep param value in our never ending global conf
controller.CONF[param.getKey("CONF_NAME")] = value
controller.CONF[param.CONF_NAME] = value
return value
@ -359,32 +359,32 @@ def _handleAnswerFileParams(answerFile):
# Handle pre conditions for group
preConditionValue = True
if group.getKey("PRE_CONDITION"):
preConditionValue = _handleGroupCondition(fconf, group.getKey("PRE_CONDITION"), preConditionValue)
if group.PRE_CONDITION:
preConditionValue = _handleGroupCondition(fconf, group.PRE_CONDITION, preConditionValue)
# Handle pre condition match with case insensitive values
logging.info("Comparing pre- conditions, value: '%s', and match: '%s'" % (preConditionValue, group.getKey("PRE_CONDITION_MATCH")))
if utils.compareStrIgnoreCase(preConditionValue, group.getKey("PRE_CONDITION_MATCH")):
for param in group.getAllParams():
_loadParamFromFile(fconf, "general", param.getKey("CONF_NAME"))
logging.info("Comparing pre- conditions, value: '%s', and match: '%s'" % (preConditionValue, group.PRE_CONDITION_MATCH))
if utils.compareStrIgnoreCase(preConditionValue, group.PRE_CONDITION_MATCH):
for param in group.parameters.itervalues():
_loadParamFromFile(fconf, "general", param.CONF_NAME)
# Handle post conditions for group only if pre condition passed
postConditionValue = True
if group.getKey("POST_CONDITION"):
postConditionValue = _handleGroupCondition(fconf, group.getKey("POST_CONDITION"), postConditionValue)
if group.POST_CONDITION:
postConditionValue = _handleGroupCondition(fconf, group.POST_CONDITION, postConditionValue)
# Handle post condition match for group
if not utils.compareStrIgnoreCase(postConditionValue, group.getKey("POST_CONDITION_MATCH")):
if not utils.compareStrIgnoreCase(postConditionValue, group.POST_CONDITION_MATCH):
logging.error("The group condition (%s) returned: %s, which differs from the excpeted output: %s"%\
(group.getKey("GROUP_NAME"), postConditionValue, group.getKey("POST_CONDITION_MATCH")))
(group.GROUP_NAME, postConditionValue, group.POST_CONDITION_MATCH))
raise ValueError(output_messages.ERR_EXP_GROUP_VALIDATION_ANS_FILE%\
(group.getKey("GROUP_NAME"), postConditionValue, group.getKey("POST_CONDITION_MATCH")))
(group.GROUP_NAME, postConditionValue, group.POST_CONDITION_MATCH))
else:
logging.debug("condition (%s) passed" % group.getKey("POST_CONDITION"))
logging.debug("condition (%s) passed" % group.POST_CONDITION)
else:
logging.debug("no post condition check for group %s" % group.getKey("GROUP_NAME"))
logging.debug("no post condition check for group %s" % group.GROUP_NAME)
else:
logging.debug("skipping params group %s since value of group validation is %s" % (group.getKey("GROUP_NAME"), preConditionValue))
logging.debug("skipping params group %s since value of group validation is %s" % (group.GROUP_NAME, preConditionValue))
except Exception as e:
logging.error(traceback.format_exc())
@ -408,24 +408,24 @@ def _getanswerfilepath():
def _handleInteractiveParams():
try:
logging.debug("Groups: %s" % ', '.join([x.getKey("GROUP_NAME") for x in controller.getAllGroups()]))
logging.debug("Groups: %s" % ', '.join([x.GROUP_NAME for x in controller.getAllGroups()]))
for group in controller.getAllGroups():
preConditionValue = True
logging.debug("going over group %s" % group.getKey("GROUP_NAME"))
logging.debug("going over group %s" % group.GROUP_NAME)
# If pre_condition is set, get Value
if group.getKey("PRE_CONDITION"):
preConditionValue = _getConditionValue(group.getKey("PRE_CONDITION"))
if group.PRE_CONDITION:
preConditionValue = _getConditionValue(group.PRE_CONDITION)
inputLoop = True
# If we have a match, i.e. condition returned True, go over all params in the group
logging.info("Comparing pre-conditions; condition: '%s', and match: '%s'" % (preConditionValue, group.getKey("PRE_CONDITION_MATCH")))
if utils.compareStrIgnoreCase(preConditionValue, group.getKey("PRE_CONDITION_MATCH")):
logging.info("Comparing pre-conditions; condition: '%s', and match: '%s'" % (preConditionValue, group.PRE_CONDITION_MATCH))
if utils.compareStrIgnoreCase(preConditionValue, group.PRE_CONDITION_MATCH):
while inputLoop:
for param in group.getAllParams():
if not param.getKey("CONDITION"):
for param in group.parameters.itervalues():
if not param.CONDITION:
input_param(param)
#update password list, so we know to mask them
_updateMaskedValueSet()
@ -434,23 +434,23 @@ def _handleInteractiveParams():
# If group has a post condition, we check it after we get the input from
# all the params in the group. if the condition returns False, we loop over the group again
if group.getKey("POST_CONDITION"):
postConditionValue = _getConditionValue(group.getKey("POST_CONDITION"))
if group.POST_CONDITION:
postConditionValue = _getConditionValue(group.POST_CONDITION)
if postConditionValue == group.getKey("POST_CONDITION_MATCH"):
if postConditionValue == group.POST_CONDITION_MATCH:
inputLoop = False
else:
#we clear the value of all params in the group
#in order to re-input them by the user
for param in group.getAllParams():
if controller.CONF.has_key(param.getKey("CONF_NAME")):
del controller.CONF[param.getKey("CONF_NAME")]
if commandLineValues.has_key(param.getKey("CONF_NAME")):
del commandLineValues[param.getKey("CONF_NAME")]
for param in group.parameters.itervalues():
if controller.CONF.has_key(param.CONF_NAME):
del controller.CONF[param.CONF_NAME]
if commandLineValues.has_key(param.CONF_NAME):
del commandLineValues[param.CONF_NAME]
else:
inputLoop = False
else:
logging.debug("no post condition check for group %s" % group.getKey("GROUP_NAME"))
logging.debug("no post condition check for group %s" % group.GROUP_NAME)
path = _getanswerfilepath()
@ -498,34 +498,34 @@ def _displaySummary():
print "=" * (len(output_messages.INFO_DSPLY_PARAMS) - 1)
logging.info("*** User input summary ***")
for group in controller.getAllGroups():
for param in group.getAllParams():
if not param.getKey("USE_DEFAULT") and controller.CONF.has_key(param.getKey("CONF_NAME")):
cmdOption = param.getKey("CMD_OPTION")
for param in group.parameters.itervalues():
if not param.USE_DEFAULT and controller.CONF.has_key(param.CONF_NAME):
cmdOption = param.CMD_OPTION
l = 30 - len(cmdOption)
maskParam = param.getKey("MASK_INPUT")
maskParam = param.MASK_INPUT
# Only call mask on a value if the param has MASK_INPUT set to True
if maskParam:
logging.info("%s: %s" % (cmdOption, mask(controller.CONF[param.getKey("CONF_NAME")])))
print "%s:" % (cmdOption) + " " * l + mask(controller.CONF[param.getKey("CONF_NAME")])
logging.info("%s: %s" % (cmdOption, mask(controller.CONF[param.CONF_NAME])))
print "%s:" % (cmdOption) + " " * l + mask(controller.CONF[param.CONF_NAME])
else:
# Otherwise, log & display it as it is
logging.info("%s: %s" % (cmdOption, str(controller.CONF[param.getKey("CONF_NAME")])))
print "%s:" % (cmdOption) + " " * l + str(controller.CONF[param.getKey("CONF_NAME")])
logging.info("%s: %s" % (cmdOption, str(controller.CONF[param.CONF_NAME])))
print "%s:" % (cmdOption) + " " * l + str(controller.CONF[param.CONF_NAME])
logging.info("*** User input summary ***")
answer = _askYesNo(output_messages.INFO_USE_PARAMS)
if not answer:
logging.debug("user chose to re-enter the user parameters")
for group in controller.getAllGroups():
for param in group.getAllParams():
if controller.CONF.has_key(param.getKey("CONF_NAME")):
if not param.getKey("MASK_INPUT"):
param.setKey("DEFAULT_VALUE", controller.CONF[param.getKey("CONF_NAME")])
for param in group.parameters.itervalues():
if controller.CONF.has_key(param.CONF_NAME):
if not param.MASK_INPUT:
param.DEFAULT_VALUE = controller.CONF[param.CONF_NAME]
# Remove the string from mask_value_set in order
# to remove values that might be over overwritten.
removeMaskString(controller.CONF[param.getKey("CONF_NAME")])
del controller.CONF[param.getKey("CONF_NAME")]
if commandLineValues.has_key(param.getKey("CONF_NAME")):
del commandLineValues[param.getKey("CONF_NAME")]
removeMaskString(controller.CONF[param.CONF_NAME])
del controller.CONF[param.CONF_NAME]
if commandLineValues.has_key(param.CONF_NAME):
del commandLineValues[param.CONF_NAME]
print ""
logging.debug("calling handleParams in interactive mode")
return _handleParams(None)
@ -564,10 +564,10 @@ def _summaryParamsToLog():
if len(controller.CONF) > 0:
logging.debug("*** The following params were used as user input:")
for group in controller.getAllGroups():
for param in group.getAllParams():
if controller.CONF.has_key(param.getKey("CONF_NAME")):
maskedValue = mask(controller.CONF[param.getKey("CONF_NAME")])
logging.debug("%s: %s" % (param.getKey("CMD_OPTION"), maskedValue ))
for param in group.parameters.itervalues():
if controller.CONF.has_key(param.CONF_NAME):
maskedValue = mask(controller.CONF[param.CONF_NAME])
logging.debug("%s: %s" % (param.CMD_OPTION, maskedValue ))
def runSequences():
@ -642,19 +642,19 @@ def generateAnswerFile(outputFile, overrides={}):
with os.fdopen(fd, "w") as ans_file:
ans_file.write("[general]%s" % os.linesep)
for group in controller.getAllGroups():
for param in group.getAllParams():
comm = param.getKey("USAGE") or ''
for param in group.parameters.itervalues():
comm = param.USAGE or ''
comm = textwrap.fill(comm,
initial_indent='%s# ' % sep,
subsequent_indent='# ',
break_long_words=False)
value = controller.CONF.get(param.getKey("CONF_NAME"),
param.getKey("DEFAULT_VALUE"))
value = controller.CONF.get(param.CONF_NAME,
param.DEFAULT_VALUE)
args = {'comment': comm,
'separator': sep,
'default_value': overrides.get(param.getKey("CONF_NAME"), value),
'conf_name': param.getKey("CONF_NAME")}
'default_value': overrides.get(param.CONF_NAME, value),
'conf_name': param.CONF_NAME}
ans_file.write(fmt % args)
def single_step_aio_install(options):
@ -691,10 +691,10 @@ def single_step_install(options):
hosts = options.install_hosts
hosts = [host.strip() for host in hosts.split(',')]
for group in controller.getAllGroups():
for param in group.getAllParams():
for param in group.parameters.itervalues():
# and directives that contain _HOST are set to the controller node
if param.getKey("CONF_NAME").find("_HOST") != -1:
overrides[param.getKey("CONF_NAME")] = hosts[0]
if param.CONF_NAME.find("_HOST") != -1:
overrides[param.CONF_NAME] = hosts[0]
# If there are more than one host, all but the first are a compute nodes
if len(hosts) > 1:
overrides["CONFIG_NOVA_COMPUTE_HOSTS"] = ','.join(hosts[1:])
@ -733,13 +733,13 @@ def initCmdLineParser():
# For each group, create a group option
for group in controller.getAllGroups():
groupParser = OptionGroup(parser, group.getKey("DESCRIPTION"))
groupParser = OptionGroup(parser, group.DESCRIPTION)
for param in group.getAllParams():
cmdOption = param.getKey("CMD_OPTION")
paramUsage = param.getKey("USAGE")
optionsList = param.getKey("OPTION_LIST")
useDefault = param.getKey("USE_DEFAULT")
for param in group.parameters.itervalues():
cmdOption = param.CMD_OPTION
paramUsage = param.USAGE
optionsList = param.OPTION_LIST
useDefault = param.USE_DEFAULT
if not useDefault:
groupParser.add_option("--%s" % cmdOption, help=paramUsage)
@ -756,14 +756,14 @@ def printOptions():
# For each group, create a group option
for group in controller.getAllGroups():
print "%s"%group.getKey("DESCRIPTION")
print "-"*len(group.getKey("DESCRIPTION"))
print "%s" % group.DESCRIPTION
print "-" * len(group.DESCRIPTION)
print
for param in group.getAllParams():
cmdOption = param.getKey("CONF_NAME")
paramUsage = param.getKey("USAGE")
optionsList = param.getKey("OPTION_LIST") or ""
for param in group.parameters.itervalues():
cmdOption = param.CONF_NAME
paramUsage = param.USAGE
optionsList = param.OPTION_LIST or ""
print "%s : %s %s"%(("**%s**"%str(cmdOption)).ljust(30), paramUsage, optionsList)
print
@ -849,9 +849,9 @@ def _set_command_line_values(options):
for key, value in options.__dict__.items():
# Replace the _ with - in the string since optparse replace _ with -
for group in controller.getAllGroups():
param = group.getParams("CMD_OPTION", key.replace("_","-"))
param = group.search("CMD_OPTION", key.replace("_","-"))
if len(param) > 0 and value:
commandLineValues[param[0].getKey("CONF_NAME")] = value
commandLineValues[param[0].CONF_NAME] = value
def main():
try:

View File

@ -105,7 +105,7 @@ class Controller(object):
def getGroupByName(self, groupName):
for group in self.getAllGroups():
if group.getKey("GROUP_NAME") == groupName:
if group.GROUP_NAME == groupName:
return group
return None
@ -114,7 +114,7 @@ class Controller(object):
def __getGroupIndexByDesc(self, name):
for group in self.getAllGroups():
if group.getKey("GROUP_NAME") == name:
if group.GROUP_NAME == name:
return self.__GROUPS.index(group)
return None
@ -131,14 +131,13 @@ class Controller(object):
def getParamByName(self, paramName):
for group in self.getAllGroups():
param = group.getParamByName(paramName)
if param:
return param
if paramName in group.parameters:
return group.parameters[paramName]
return None
def getParamKeyValue(self, paramName, keyName):
param = self.getParamByName(paramName)
if param:
return param.getKey(keyName)
return getattr(param, keyName)
else:
return None

View File

@ -1,86 +1,47 @@
# -*- coding: utf-8 -*-
"""
Container set for groups and parameters
"""
class Param(object):
allowed_keys = ('CMD_OPTION','USAGE','PROMPT','OPTION_LIST',
'PROCESSORS', 'VALIDATORS','DEFAULT_VALUE',
'MASK_INPUT', 'LOOSE_VALIDATION', 'CONF_NAME',
'USE_DEFAULT','NEED_CONFIRM','CONDITION')
from .datastructures import SortedDict
class Parameter(object):
allowed_keys = ('CONF_NAME', 'CMD_OPTION', 'USAGE', 'PROMPT',
'PROCESSORS', 'VALIDATORS', 'LOOSE_VALIDATION',
'DEFAULT_VALUE', 'USE_DEFAULT', 'OPTION_LIST',
'MASK_INPUT', 'NEED_CONFIRM','CONDITION')
def __init__(self, attributes=None):
if not attributes:
self.__ATTRIBUTES = {}.fromkeys(self.allowed_keys)
return
attributes = attributes or {}
defaults = {}.fromkeys(self.allowed_keys)
defaults.update(attributes)
self.__ATTRIBUTES = {}
for key, value in attributes.iteritems():
for key, value in defaults.iteritems():
if key not in self.allowed_keys:
raise KeyError('Given attribute %s is '
'not allowed' % key)
self.__ATTRIBUTES[key] = value
raise KeyError('Given attribute %s is not allowed' % key)
self.__dict__[key] = value
def setKey(self, key, value):
self.validateKey(key)
self.__ATTRIBUTES[key] = value
def getKey(self, key):
self.validateKey(key)
return self.__ATTRIBUTES.get(key)
class Group(Parameter):
allowed_keys = ('GROUP_NAME', 'DESCRIPTION', 'PRE_CONDITION',
'PRE_CONDITION_MATCH', 'POST_CONDITION',
'POST_CONDITION_MATCH')
def validateKey(self, key):
if key not in self.allowed_keys:
raise KeyError("%s is not a valid key" % key)
def __init__(self, attributes=None, parameters=None):
super(Group, self).__init__(attributes)
self.parameters = SortedDict()
for param in parameters or []:
self.parameters[param['CONF_NAME']] = Parameter(attributes=param)
class Group(Param):
allowed_keys = ('GROUP_NAME', 'DESCRIPTION', 'PRE_CONDITION', 'PRE_CONDITION_MATCH', 'POST_CONDITION', 'POST_CONDITION_MATCH')
def __init__(self, attributes={}, params=[]):
self.__PARAMS = []
Param.__init__(self, attributes)
for param in params:
self.addParam(param)
def addParam(self,paramDict):
p = Param(paramDict)
self.__PARAMS.append(p)
def getParamByName(self,paramName):
for param in self.__PARAMS:
if param.getKey("CONF_NAME") == paramName:
return param
return None
def getAllParams(self):
return self.__PARAMS
def getParams(self,paramKey, paramValue):
output = []
for param in self.__PARAMS:
if param.getKey(paramKey) == paramValue:
output.append(param)
return output
def __getParamIndexByDesc(self, name):
for param in self.getAllParams():
if param.getKey("CONF_NAME") == name:
return self.__PARAMS.index(param)
return None
def insertParamBeforeParam(self, paramName, param):
def search(self, attr, value):
"""
Insert a param before a named param.
i.e. if the specified param name is "update x", the new
param will be inserted BEFORE "update x"
Returns list of parameters which have given attribute of given
value.
"""
index = self.__getParamIndexByDesc(paramName)
if index == None:
index = len(self.getAllParams())
self.__PARAMS.insert(index, Param(param))
def removeParamByName(self, paramName):
self.__removeParams("CONF_NAME", paramName)
def __removeParams(self, paramKey, paramValue):
list = self.getParams(paramKey, paramValue)
for item in list:
self.__PARAMS.remove(item)
result = []
for param in self.parameters.itervalues():
if getattr(param, attr) == value:
result.append(param)
return result

View File

@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
"""
Test cases for packstack.installer.setup_params module.
"""
from unittest import TestCase
from ..test_base import PackstackTestCaseMixin
from packstack.installer.setup_params import *
class ParameterTestCase(PackstackTestCaseMixin, TestCase):
def setUp(self):
super(ParameterTestCase, self).setUp()
self.data = {
"CMD_OPTION": "mysql-host",
"USAGE": ("The IP address of the server on which to "
"install MySQL"),
"PROMPT": "Enter the IP address of the MySQL server",
"OPTION_LIST": [],
"VALIDATORS": [],
"DEFAULT_VALUE": "127.0.0.1",
"MASK_INPUT": False,
"LOOSE_VALIDATION": True,
"CONF_NAME": "CONFIG_MYSQL_HOST",
"USE_DEFAULT": False,
"NEED_CONFIRM": False,
"CONDITION": False}
def test_parameter_init(self):
"""
Test packstack.installer.setup_params.Parameter initialization
"""
param = Parameter(self.data)
for key, value in self.data.iteritems():
self.assertEqual(getattr(param, key), value)
def test_default_attribute(self):
"""
Test packstack.installer.setup_params.Parameter default value
"""
param = Parameter()
self.assertIsNone(param.PROCESSORS)
class GroupTestCase(PackstackTestCaseMixin, TestCase):
def setUp(self):
super(GroupTestCase, self).setUp()
self.attrs = {
"GROUP_NAME": "MYSQL",
"DESCRIPTION": "MySQL Config parameters",
"PRE_CONDITION": "y",
"PRE_CONDITION_MATCH": "y",
"POST_CONDITION": False,
"POST_CONDITION_MATCH": False}
self.params = [
{"CONF_NAME": "CONFIG_MYSQL_HOST", "PROMPT": "find_me"},
{"CONF_NAME": "CONFIG_MYSQL_USER"},
{"CONF_NAME": "CONFIG_MYSQL_PW"}]
def test_group_init(self):
"""Test packstack.installer.setup_params.Group initialization"""
group = Group(attributes=self.attrs, parameters=self.params)
for key, value in self.attrs.iteritems():
self.assertEqual(getattr(group, key), value)
for param in self.params:
self.assertIn(param['CONF_NAME'], group.parameters)
def test_search(self):
"""Test packstack.installer.setup_params.Group search method"""
group = Group(attributes=self.attrs, parameters=self.params)
param_list = group.search('PROMPT', 'find_me')
self.assertEqual(len(param_list), 1)
self.assertIsInstance(param_list[0], Parameter)
self.assertEqual(param_list[0].CONF_NAME, 'CONFIG_MYSQL_HOST')

View File

@ -86,3 +86,8 @@ class PackstackTestCaseMixin(object):
_msg = msg or ('%s is not a member of %s' % (first, second))
if first not in second:
raise AssertionError(_msg)
def assertIsNone(self, expr, msg=None):
_msg = msg or ('%s is not None' % expr)
if expr is not None:
raise AssertionError(_msg)