diff --git a/packstack/installer/datastructures.py b/packstack/installer/datastructures.py new file mode 100644 index 000000000..542f18dcb --- /dev/null +++ b/packstack/installer/datastructures.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- + +import copy +from types import GeneratorType + + +# taken from Django.utils.datastructures +class SortedDict(dict): + """ + A dictionary that keeps its keys in the order in which they're inserted. + """ + def __new__(cls, *args, **kwargs): + instance = super(SortedDict, cls).__new__(cls, *args, **kwargs) + instance.keyOrder = [] + return instance + + def __init__(self, data=None): + if data is None: + data = {} + elif isinstance(data, GeneratorType): + # Unfortunately we need to be able to read a generator twice. Once + # to get the data into self with our super().__init__ call and a + # second time to setup keyOrder correctly + data = list(data) + super(SortedDict, self).__init__(data) + if isinstance(data, dict): + self.keyOrder = data.keys() + else: + self.keyOrder = [] + seen = set() + for key, value in data: + if key not in seen: + self.keyOrder.append(key) + seen.add(key) + + def __deepcopy__(self, memo): + return self.__class__([(key, copy.deepcopy(value, memo)) + for key, value in self.iteritems()]) + + def __setitem__(self, key, value): + if key not in self: + self.keyOrder.append(key) + super(SortedDict, self).__setitem__(key, value) + + def __delitem__(self, key): + super(SortedDict, self).__delitem__(key) + self.keyOrder.remove(key) + + def __iter__(self): + return iter(self.keyOrder) + + def pop(self, k, *args): + result = super(SortedDict, self).pop(k, *args) + try: + self.keyOrder.remove(k) + except ValueError: + # Key wasn't in the dictionary in the first place. No problem. + pass + return result + + def popitem(self): + result = super(SortedDict, self).popitem() + self.keyOrder.remove(result[0]) + return result + + def items(self): + return zip(self.keyOrder, self.values()) + + def iteritems(self): + for key in self.keyOrder: + yield key, self[key] + + def keys(self): + return self.keyOrder[:] + + def iterkeys(self): + return iter(self.keyOrder) + + def values(self): + return map(self.__getitem__, self.keyOrder) + + def itervalues(self): + for key in self.keyOrder: + yield self[key] + + def update(self, dict_): + for k, v in dict_.iteritems(): + self[k] = v + + def setdefault(self, key, default): + if key not in self: + self.keyOrder.append(key) + return super(SortedDict, self).setdefault(key, default) + + def value_for_index(self, index): + """Returns the value of the item at the given zero-based index.""" + return self[self.keyOrder[index]] + + def insert(self, index, key, value): + """Inserts the key, value pair before the item with the given index.""" + if key in self.keyOrder: + n = self.keyOrder.index(key) + del self.keyOrder[n] + if n < index: + index -= 1 + self.keyOrder.insert(index, key) + super(SortedDict, self).__setitem__(key, value) + + def copy(self): + """Returns a copy of this object.""" + # This way of initializing the copy means it works for subclasses, too. + obj = self.__class__(self) + obj.keyOrder = self.keyOrder[:] + return obj + + def __repr__(self): + """ + Replaces the normal dict.__repr__ with a version that returns the keys + in their sorted order. + """ + return '{%s}' % ', '.join(['%r: %r' % (k, v) for k, v in self.items()]) + + def clear(self): + super(SortedDict, self).clear() + self.keyOrder = [] diff --git a/packstack/installer/run_setup.py b/packstack/installer/run_setup.py index 40364d8d2..d5d211683 100644 --- a/packstack/installer/run_setup.py +++ b/packstack/installer/run_setup.py @@ -70,37 +70,37 @@ def _getInputFromUser(param): userInput = None try: - if param.getKey("USE_DEFAULT"): - logging.debug("setting default value (%s) for key (%s)" % (mask(param.getKey("DEFAULT_VALUE")), param.getKey("CONF_NAME"))) - controller.CONF[param.getKey("CONF_NAME")] = param.getKey("DEFAULT_VALUE") + if param.USE_DEFAULT: + logging.debug("setting default value (%s) for key (%s)" % (mask(param.DEFAULT_VALUE), param.CONF_NAME)) + controller.CONF[param.CONF_NAME] = param.DEFAULT_VALUE else: while loop: # If the value was not supplied by the command line flags - if not commandLineValues.has_key(param.getKey("CONF_NAME")): + if not commandLineValues.has_key(param.CONF_NAME): message = StringIO() - message.write(param.getKey("PROMPT")) + message.write(param.PROMPT) - val_list = param.getKey("VALIDATORS") or [] + val_list = param.VALIDATORS or [] if validators.validate_regexp not in val_list \ - and param.getKey("OPTION_LIST"): - message.write(" [%s]" % "|".join(param.getKey("OPTION_LIST"))) + and param.OPTION_LIST: + message.write(" [%s]" % "|".join(param.OPTION_LIST)) - if param.getKey("DEFAULT_VALUE"): - message.write(" [%s] " % (str(param.getKey("DEFAULT_VALUE")))) + if param.DEFAULT_VALUE: + message.write(" [%s] " % (str(param.DEFAULT_VALUE))) message.write(": ") message.seek(0) #mask password or hidden fields - if (param.getKey("MASK_INPUT")): - userInput = getpass.getpass("%s :" % (param.getKey("PROMPT"))) + if (param.MASK_INPUT): + userInput = getpass.getpass("%s :" % (param.PROMPT)) else: userInput = raw_input(message.read()) else: - userInput = commandLineValues[param.getKey("CONF_NAME")] + userInput = commandLineValues[param.CONF_NAME] # If DEFAULT_VALUE is set and user did not input anything - if userInput == "" and len(str(param.getKey("DEFAULT_VALUE"))) > 0: - userInput = param.getKey("DEFAULT_VALUE") + if userInput == "" and len(str(param.DEFAULT_VALUE)) > 0: + userInput = param.DEFAULT_VALUE # Param processing userInput = process_param_value(param, userInput) @@ -108,31 +108,31 @@ def _getInputFromUser(param): # If param requires validation try: validate_param_value(param, userInput) - controller.CONF[param.getKey("CONF_NAME")] = userInput + controller.CONF[param.CONF_NAME] = userInput loop = False except ParamValidationError: - if param.getKey("LOOSE_VALIDATION"): + if param.LOOSE_VALIDATION: # If validation failed but LOOSE_VALIDATION is true, ask user answer = _askYesNo("User input failed validation, " "do you still wish to use it") loop = not answer if answer: - controller.CONF[param.getKey("CONF_NAME")] = userInput + controller.CONF[param.CONF_NAME] = userInput continue else: - if commandLineValues.has_key(param.getKey("CONF_NAME")): - del commandLineValues[param.getKey("CONF_NAME")] + if commandLineValues.has_key(param.CONF_NAME): + del commandLineValues[param.CONF_NAME] else: # Delete value from commandLineValues so that we will prompt the user for input - if commandLineValues.has_key(param.getKey("CONF_NAME")): - del commandLineValues[param.getKey("CONF_NAME")] + if commandLineValues.has_key(param.CONF_NAME): + del commandLineValues[param.CONF_NAME] loop = True except KeyboardInterrupt: print "" # add the new line so messages wont be displayed in the same line as the question raise except: logging.error(traceback.format_exc()) - raise Exception(output_messages.ERR_EXP_READ_INPUT_PARAM % (param.getKey("CONF_NAME"))) + raise Exception(output_messages.ERR_EXP_READ_INPUT_PARAM % (param.CONF_NAME)) def input_param(param): """ @@ -141,18 +141,18 @@ def input_param(param): """ # We need to check if a param needs confirmation, (i.e. ask user twice) # Do not validate if it was given from the command line - if (param.getKey("NEED_CONFIRM") and not commandLineValues.has_key(param.getKey("CONF_NAME"))): + if (param.NEED_CONFIRM and not commandLineValues.has_key(param.CONF_NAME)): #create a copy of the param so we can call it twice confirmedParam = copy.deepcopy(param) - confirmedParamName = param.getKey("CONF_NAME") + "_CONFIRMED" - confirmedParam.setKey("CONF_NAME", confirmedParamName) - confirmedParam.setKey("PROMPT", output_messages.INFO_CONF_PARAMS_PASSWD_CONFIRM_PROMPT) - confirmedParam.setKey("VALIDATORS", [validators.validate_not_empty]) + confirmedParamName = param.CONF_NAME + "_CONFIRMED" + confirmedParam.CONF_NAME = confirmedParamName + confirmedParam.PROMPT = output_messages.INFO_CONF_PARAMS_PASSWD_CONFIRM_PROMPT + confirmedParam.VALIDATORS = [validators.validate_not_empty] # Now get both values from user (with existing validations while True: _getInputFromUser(param) _getInputFromUser(confirmedParam) - if controller.CONF[param.getKey("CONF_NAME")] == controller.CONF[confirmedParamName]: + if controller.CONF[param.CONF_NAME] == controller.CONF[confirmedParamName]: logging.debug("Param confirmation passed, value for both questions is identical") break else: @@ -192,10 +192,10 @@ def _addDefaultsToMaskedValueSet(): """ global masked_value_set for group in controller.getAllGroups(): - for param in group.getAllParams(): + for param in group.parameters.itervalues(): # Keep default password values masked, but ignore default empty values - if ((param.getKey("MASK_INPUT") == True) and param.getKey("DEFAULT_VALUE") != ""): - masked_value_set.add(param.getKey("DEFAULT_VALUE")) + if ((param.MASK_INPUT == True) and param.DEFAULT_VALUE != ""): + masked_value_set.add(param.DEFAULT_VALUE) def _updateMaskedValueSet(): """ @@ -261,11 +261,11 @@ def maskString(str): return str def validate_param_value(param, value): - cname = param.getKey("CONF_NAME") + cname = param.CONF_NAME logging.debug("Validating parameter %s." % cname) - val_list = param.getKey("VALIDATORS") or [] - opt_list = param.getKey("OPTION_LIST") + val_list = param.VALIDATORS or [] + opt_list = param.OPTION_LIST for val_func in val_list: try: val_func(value, opt_list) @@ -275,10 +275,10 @@ def validate_param_value(param, value): def process_param_value(param, value): _value = value - processors = param.getKey("PROCESSORS") or [] + processors = param.PROCESSORS or [] for proc_func in processors: logging.debug("Processing value of parameter " - "%s." % param.getKey("CONF_NAME")) + "%s." % param.CONF_NAME) try: new_value = proc_func(_value, controller.CONF) if new_value != _value: @@ -290,7 +290,7 @@ def process_param_value(param, value): "value: %s" % _value) except processors.ParamProcessingError, ex: print ("Value processing of parameter %s " - "failed.\n%s" % (param.getKey("CONF_NAME"), ex)) + "failed.\n%s" % (param.CONF_NAME, ex)) raise return _value @@ -336,7 +336,7 @@ def _loadParamFromFile(config, section, paramName): validate_param_value(param, value) # Keep param value in our never ending global conf - controller.CONF[param.getKey("CONF_NAME")] = value + controller.CONF[param.CONF_NAME] = value return value @@ -359,32 +359,32 @@ def _handleAnswerFileParams(answerFile): # Handle pre conditions for group preConditionValue = True - if group.getKey("PRE_CONDITION"): - preConditionValue = _handleGroupCondition(fconf, group.getKey("PRE_CONDITION"), preConditionValue) + if group.PRE_CONDITION: + preConditionValue = _handleGroupCondition(fconf, group.PRE_CONDITION, preConditionValue) # Handle pre condition match with case insensitive values - logging.info("Comparing pre- conditions, value: '%s', and match: '%s'" % (preConditionValue, group.getKey("PRE_CONDITION_MATCH"))) - if utils.compareStrIgnoreCase(preConditionValue, group.getKey("PRE_CONDITION_MATCH")): - for param in group.getAllParams(): - _loadParamFromFile(fconf, "general", param.getKey("CONF_NAME")) + logging.info("Comparing pre- conditions, value: '%s', and match: '%s'" % (preConditionValue, group.PRE_CONDITION_MATCH)) + if utils.compareStrIgnoreCase(preConditionValue, group.PRE_CONDITION_MATCH): + for param in group.parameters.itervalues(): + _loadParamFromFile(fconf, "general", param.CONF_NAME) # Handle post conditions for group only if pre condition passed postConditionValue = True - if group.getKey("POST_CONDITION"): - postConditionValue = _handleGroupCondition(fconf, group.getKey("POST_CONDITION"), postConditionValue) + if group.POST_CONDITION: + postConditionValue = _handleGroupCondition(fconf, group.POST_CONDITION, postConditionValue) # Handle post condition match for group - if not utils.compareStrIgnoreCase(postConditionValue, group.getKey("POST_CONDITION_MATCH")): + if not utils.compareStrIgnoreCase(postConditionValue, group.POST_CONDITION_MATCH): logging.error("The group condition (%s) returned: %s, which differs from the excpeted output: %s"%\ - (group.getKey("GROUP_NAME"), postConditionValue, group.getKey("POST_CONDITION_MATCH"))) + (group.GROUP_NAME, postConditionValue, group.POST_CONDITION_MATCH)) raise ValueError(output_messages.ERR_EXP_GROUP_VALIDATION_ANS_FILE%\ - (group.getKey("GROUP_NAME"), postConditionValue, group.getKey("POST_CONDITION_MATCH"))) + (group.GROUP_NAME, postConditionValue, group.POST_CONDITION_MATCH)) else: - logging.debug("condition (%s) passed" % group.getKey("POST_CONDITION")) + logging.debug("condition (%s) passed" % group.POST_CONDITION) else: - logging.debug("no post condition check for group %s" % group.getKey("GROUP_NAME")) + logging.debug("no post condition check for group %s" % group.GROUP_NAME) else: - logging.debug("skipping params group %s since value of group validation is %s" % (group.getKey("GROUP_NAME"), preConditionValue)) + logging.debug("skipping params group %s since value of group validation is %s" % (group.GROUP_NAME, preConditionValue)) except Exception as e: logging.error(traceback.format_exc()) @@ -408,24 +408,24 @@ def _getanswerfilepath(): def _handleInteractiveParams(): try: - logging.debug("Groups: %s" % ', '.join([x.getKey("GROUP_NAME") for x in controller.getAllGroups()])) + logging.debug("Groups: %s" % ', '.join([x.GROUP_NAME for x in controller.getAllGroups()])) for group in controller.getAllGroups(): preConditionValue = True - logging.debug("going over group %s" % group.getKey("GROUP_NAME")) + logging.debug("going over group %s" % group.GROUP_NAME) # If pre_condition is set, get Value - if group.getKey("PRE_CONDITION"): - preConditionValue = _getConditionValue(group.getKey("PRE_CONDITION")) + if group.PRE_CONDITION: + preConditionValue = _getConditionValue(group.PRE_CONDITION) inputLoop = True # If we have a match, i.e. condition returned True, go over all params in the group - logging.info("Comparing pre-conditions; condition: '%s', and match: '%s'" % (preConditionValue, group.getKey("PRE_CONDITION_MATCH"))) - if utils.compareStrIgnoreCase(preConditionValue, group.getKey("PRE_CONDITION_MATCH")): + logging.info("Comparing pre-conditions; condition: '%s', and match: '%s'" % (preConditionValue, group.PRE_CONDITION_MATCH)) + if utils.compareStrIgnoreCase(preConditionValue, group.PRE_CONDITION_MATCH): while inputLoop: - for param in group.getAllParams(): - if not param.getKey("CONDITION"): + for param in group.parameters.itervalues(): + if not param.CONDITION: input_param(param) #update password list, so we know to mask them _updateMaskedValueSet() @@ -434,23 +434,23 @@ def _handleInteractiveParams(): # If group has a post condition, we check it after we get the input from # all the params in the group. if the condition returns False, we loop over the group again - if group.getKey("POST_CONDITION"): - postConditionValue = _getConditionValue(group.getKey("POST_CONDITION")) + if group.POST_CONDITION: + postConditionValue = _getConditionValue(group.POST_CONDITION) - if postConditionValue == group.getKey("POST_CONDITION_MATCH"): + if postConditionValue == group.POST_CONDITION_MATCH: inputLoop = False else: #we clear the value of all params in the group #in order to re-input them by the user - for param in group.getAllParams(): - if controller.CONF.has_key(param.getKey("CONF_NAME")): - del controller.CONF[param.getKey("CONF_NAME")] - if commandLineValues.has_key(param.getKey("CONF_NAME")): - del commandLineValues[param.getKey("CONF_NAME")] + for param in group.parameters.itervalues(): + if controller.CONF.has_key(param.CONF_NAME): + del controller.CONF[param.CONF_NAME] + if commandLineValues.has_key(param.CONF_NAME): + del commandLineValues[param.CONF_NAME] else: inputLoop = False else: - logging.debug("no post condition check for group %s" % group.getKey("GROUP_NAME")) + logging.debug("no post condition check for group %s" % group.GROUP_NAME) path = _getanswerfilepath() @@ -498,34 +498,34 @@ def _displaySummary(): print "=" * (len(output_messages.INFO_DSPLY_PARAMS) - 1) logging.info("*** User input summary ***") for group in controller.getAllGroups(): - for param in group.getAllParams(): - if not param.getKey("USE_DEFAULT") and controller.CONF.has_key(param.getKey("CONF_NAME")): - cmdOption = param.getKey("CMD_OPTION") + for param in group.parameters.itervalues(): + if not param.USE_DEFAULT and controller.CONF.has_key(param.CONF_NAME): + cmdOption = param.CMD_OPTION l = 30 - len(cmdOption) - maskParam = param.getKey("MASK_INPUT") + maskParam = param.MASK_INPUT # Only call mask on a value if the param has MASK_INPUT set to True if maskParam: - logging.info("%s: %s" % (cmdOption, mask(controller.CONF[param.getKey("CONF_NAME")]))) - print "%s:" % (cmdOption) + " " * l + mask(controller.CONF[param.getKey("CONF_NAME")]) + logging.info("%s: %s" % (cmdOption, mask(controller.CONF[param.CONF_NAME]))) + print "%s:" % (cmdOption) + " " * l + mask(controller.CONF[param.CONF_NAME]) else: # Otherwise, log & display it as it is - logging.info("%s: %s" % (cmdOption, str(controller.CONF[param.getKey("CONF_NAME")]))) - print "%s:" % (cmdOption) + " " * l + str(controller.CONF[param.getKey("CONF_NAME")]) + logging.info("%s: %s" % (cmdOption, str(controller.CONF[param.CONF_NAME]))) + print "%s:" % (cmdOption) + " " * l + str(controller.CONF[param.CONF_NAME]) logging.info("*** User input summary ***") answer = _askYesNo(output_messages.INFO_USE_PARAMS) if not answer: logging.debug("user chose to re-enter the user parameters") for group in controller.getAllGroups(): - for param in group.getAllParams(): - if controller.CONF.has_key(param.getKey("CONF_NAME")): - if not param.getKey("MASK_INPUT"): - param.setKey("DEFAULT_VALUE", controller.CONF[param.getKey("CONF_NAME")]) + for param in group.parameters.itervalues(): + if controller.CONF.has_key(param.CONF_NAME): + if not param.MASK_INPUT: + param.DEFAULT_VALUE = controller.CONF[param.CONF_NAME] # Remove the string from mask_value_set in order # to remove values that might be over overwritten. - removeMaskString(controller.CONF[param.getKey("CONF_NAME")]) - del controller.CONF[param.getKey("CONF_NAME")] - if commandLineValues.has_key(param.getKey("CONF_NAME")): - del commandLineValues[param.getKey("CONF_NAME")] + removeMaskString(controller.CONF[param.CONF_NAME]) + del controller.CONF[param.CONF_NAME] + if commandLineValues.has_key(param.CONF_NAME): + del commandLineValues[param.CONF_NAME] print "" logging.debug("calling handleParams in interactive mode") return _handleParams(None) @@ -564,10 +564,10 @@ def _summaryParamsToLog(): if len(controller.CONF) > 0: logging.debug("*** The following params were used as user input:") for group in controller.getAllGroups(): - for param in group.getAllParams(): - if controller.CONF.has_key(param.getKey("CONF_NAME")): - maskedValue = mask(controller.CONF[param.getKey("CONF_NAME")]) - logging.debug("%s: %s" % (param.getKey("CMD_OPTION"), maskedValue )) + for param in group.parameters.itervalues(): + if controller.CONF.has_key(param.CONF_NAME): + maskedValue = mask(controller.CONF[param.CONF_NAME]) + logging.debug("%s: %s" % (param.CMD_OPTION, maskedValue )) def runSequences(): @@ -642,19 +642,19 @@ def generateAnswerFile(outputFile, overrides={}): with os.fdopen(fd, "w") as ans_file: ans_file.write("[general]%s" % os.linesep) for group in controller.getAllGroups(): - for param in group.getAllParams(): - comm = param.getKey("USAGE") or '' + for param in group.parameters.itervalues(): + comm = param.USAGE or '' comm = textwrap.fill(comm, initial_indent='%s# ' % sep, subsequent_indent='# ', break_long_words=False) - value = controller.CONF.get(param.getKey("CONF_NAME"), - param.getKey("DEFAULT_VALUE")) + value = controller.CONF.get(param.CONF_NAME, + param.DEFAULT_VALUE) args = {'comment': comm, 'separator': sep, - 'default_value': overrides.get(param.getKey("CONF_NAME"), value), - 'conf_name': param.getKey("CONF_NAME")} + 'default_value': overrides.get(param.CONF_NAME, value), + 'conf_name': param.CONF_NAME} ans_file.write(fmt % args) def single_step_aio_install(options): @@ -691,10 +691,10 @@ def single_step_install(options): hosts = options.install_hosts hosts = [host.strip() for host in hosts.split(',')] for group in controller.getAllGroups(): - for param in group.getAllParams(): + for param in group.parameters.itervalues(): # and directives that contain _HOST are set to the controller node - if param.getKey("CONF_NAME").find("_HOST") != -1: - overrides[param.getKey("CONF_NAME")] = hosts[0] + if param.CONF_NAME.find("_HOST") != -1: + overrides[param.CONF_NAME] = hosts[0] # If there are more than one host, all but the first are a compute nodes if len(hosts) > 1: overrides["CONFIG_NOVA_COMPUTE_HOSTS"] = ','.join(hosts[1:]) @@ -733,13 +733,13 @@ def initCmdLineParser(): # For each group, create a group option for group in controller.getAllGroups(): - groupParser = OptionGroup(parser, group.getKey("DESCRIPTION")) + groupParser = OptionGroup(parser, group.DESCRIPTION) - for param in group.getAllParams(): - cmdOption = param.getKey("CMD_OPTION") - paramUsage = param.getKey("USAGE") - optionsList = param.getKey("OPTION_LIST") - useDefault = param.getKey("USE_DEFAULT") + for param in group.parameters.itervalues(): + cmdOption = param.CMD_OPTION + paramUsage = param.USAGE + optionsList = param.OPTION_LIST + useDefault = param.USE_DEFAULT if not useDefault: groupParser.add_option("--%s" % cmdOption, help=paramUsage) @@ -756,14 +756,14 @@ def printOptions(): # For each group, create a group option for group in controller.getAllGroups(): - print "%s"%group.getKey("DESCRIPTION") - print "-"*len(group.getKey("DESCRIPTION")) + print "%s" % group.DESCRIPTION + print "-" * len(group.DESCRIPTION) print - for param in group.getAllParams(): - cmdOption = param.getKey("CONF_NAME") - paramUsage = param.getKey("USAGE") - optionsList = param.getKey("OPTION_LIST") or "" + for param in group.parameters.itervalues(): + cmdOption = param.CONF_NAME + paramUsage = param.USAGE + optionsList = param.OPTION_LIST or "" print "%s : %s %s"%(("**%s**"%str(cmdOption)).ljust(30), paramUsage, optionsList) print @@ -849,9 +849,9 @@ def _set_command_line_values(options): for key, value in options.__dict__.items(): # Replace the _ with - in the string since optparse replace _ with - for group in controller.getAllGroups(): - param = group.getParams("CMD_OPTION", key.replace("_","-")) + param = group.search("CMD_OPTION", key.replace("_","-")) if len(param) > 0 and value: - commandLineValues[param[0].getKey("CONF_NAME")] = value + commandLineValues[param[0].CONF_NAME] = value def main(): try: diff --git a/packstack/installer/setup_controller.py b/packstack/installer/setup_controller.py index e6b3e294d..bd4884fb7 100644 --- a/packstack/installer/setup_controller.py +++ b/packstack/installer/setup_controller.py @@ -105,7 +105,7 @@ class Controller(object): def getGroupByName(self, groupName): for group in self.getAllGroups(): - if group.getKey("GROUP_NAME") == groupName: + if group.GROUP_NAME == groupName: return group return None @@ -114,7 +114,7 @@ class Controller(object): def __getGroupIndexByDesc(self, name): for group in self.getAllGroups(): - if group.getKey("GROUP_NAME") == name: + if group.GROUP_NAME == name: return self.__GROUPS.index(group) return None @@ -131,14 +131,13 @@ class Controller(object): def getParamByName(self, paramName): for group in self.getAllGroups(): - param = group.getParamByName(paramName) - if param: - return param + if paramName in group.parameters: + return group.parameters[paramName] return None def getParamKeyValue(self, paramName, keyName): param = self.getParamByName(paramName) if param: - return param.getKey(keyName) + return getattr(param, keyName) else: return None diff --git a/packstack/installer/setup_params.py b/packstack/installer/setup_params.py index 5a92ff5cf..90cf55cf3 100644 --- a/packstack/installer/setup_params.py +++ b/packstack/installer/setup_params.py @@ -1,86 +1,47 @@ +# -*- coding: utf-8 -*- + """ Container set for groups and parameters """ -class Param(object): - allowed_keys = ('CMD_OPTION','USAGE','PROMPT','OPTION_LIST', - 'PROCESSORS', 'VALIDATORS','DEFAULT_VALUE', - 'MASK_INPUT', 'LOOSE_VALIDATION', 'CONF_NAME', - 'USE_DEFAULT','NEED_CONFIRM','CONDITION') + +from .datastructures import SortedDict + + +class Parameter(object): + allowed_keys = ('CONF_NAME', 'CMD_OPTION', 'USAGE', 'PROMPT', + 'PROCESSORS', 'VALIDATORS', 'LOOSE_VALIDATION', + 'DEFAULT_VALUE', 'USE_DEFAULT', 'OPTION_LIST', + 'MASK_INPUT', 'NEED_CONFIRM','CONDITION') def __init__(self, attributes=None): - if not attributes: - self.__ATTRIBUTES = {}.fromkeys(self.allowed_keys) - return + attributes = attributes or {} + defaults = {}.fromkeys(self.allowed_keys) + defaults.update(attributes) - self.__ATTRIBUTES = {} - for key, value in attributes.iteritems(): + for key, value in defaults.iteritems(): if key not in self.allowed_keys: - raise KeyError('Given attribute %s is ' - 'not allowed' % key) - self.__ATTRIBUTES[key] = value + raise KeyError('Given attribute %s is not allowed' % key) + self.__dict__[key] = value - def setKey(self, key, value): - self.validateKey(key) - self.__ATTRIBUTES[key] = value - def getKey(self, key): - self.validateKey(key) - return self.__ATTRIBUTES.get(key) +class Group(Parameter): + allowed_keys = ('GROUP_NAME', 'DESCRIPTION', 'PRE_CONDITION', + 'PRE_CONDITION_MATCH', 'POST_CONDITION', + 'POST_CONDITION_MATCH') - def validateKey(self, key): - if key not in self.allowed_keys: - raise KeyError("%s is not a valid key" % key) + def __init__(self, attributes=None, parameters=None): + super(Group, self).__init__(attributes) + self.parameters = SortedDict() + for param in parameters or []: + self.parameters[param['CONF_NAME']] = Parameter(attributes=param) -class Group(Param): - allowed_keys = ('GROUP_NAME', 'DESCRIPTION', 'PRE_CONDITION', 'PRE_CONDITION_MATCH', 'POST_CONDITION', 'POST_CONDITION_MATCH') - - def __init__(self, attributes={}, params=[]): - self.__PARAMS = [] - Param.__init__(self, attributes) - for param in params: - self.addParam(param) - - def addParam(self,paramDict): - p = Param(paramDict) - self.__PARAMS.append(p) - - def getParamByName(self,paramName): - for param in self.__PARAMS: - if param.getKey("CONF_NAME") == paramName: - return param - return None - - def getAllParams(self): - return self.__PARAMS - - def getParams(self,paramKey, paramValue): - output = [] - for param in self.__PARAMS: - if param.getKey(paramKey) == paramValue: - output.append(param) - return output - - def __getParamIndexByDesc(self, name): - for param in self.getAllParams(): - if param.getKey("CONF_NAME") == name: - return self.__PARAMS.index(param) - return None - - def insertParamBeforeParam(self, paramName, param): + def search(self, attr, value): """ - Insert a param before a named param. - i.e. if the specified param name is "update x", the new - param will be inserted BEFORE "update x" + Returns list of parameters which have given attribute of given + value. """ - index = self.__getParamIndexByDesc(paramName) - if index == None: - index = len(self.getAllParams()) - self.__PARAMS.insert(index, Param(param)) - - def removeParamByName(self, paramName): - self.__removeParams("CONF_NAME", paramName) - - def __removeParams(self, paramKey, paramValue): - list = self.getParams(paramKey, paramValue) - for item in list: - self.__PARAMS.remove(item) + result = [] + for param in self.parameters.itervalues(): + if getattr(param, attr) == value: + result.append(param) + return result diff --git a/tests/installer/test_setup_params.py b/tests/installer/test_setup_params.py new file mode 100644 index 000000000..7e32b287d --- /dev/null +++ b/tests/installer/test_setup_params.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +""" +Test cases for packstack.installer.setup_params module. +""" + +from unittest import TestCase + +from ..test_base import PackstackTestCaseMixin +from packstack.installer.setup_params import * + + +class ParameterTestCase(PackstackTestCaseMixin, TestCase): + def setUp(self): + super(ParameterTestCase, self).setUp() + self.data = { + "CMD_OPTION": "mysql-host", + "USAGE": ("The IP address of the server on which to " + "install MySQL"), + "PROMPT": "Enter the IP address of the MySQL server", + "OPTION_LIST": [], + "VALIDATORS": [], + "DEFAULT_VALUE": "127.0.0.1", + "MASK_INPUT": False, + "LOOSE_VALIDATION": True, + "CONF_NAME": "CONFIG_MYSQL_HOST", + "USE_DEFAULT": False, + "NEED_CONFIRM": False, + "CONDITION": False} + + def test_parameter_init(self): + """ + Test packstack.installer.setup_params.Parameter initialization + """ + param = Parameter(self.data) + for key, value in self.data.iteritems(): + self.assertEqual(getattr(param, key), value) + + def test_default_attribute(self): + """ + Test packstack.installer.setup_params.Parameter default value + """ + param = Parameter() + self.assertIsNone(param.PROCESSORS) + + +class GroupTestCase(PackstackTestCaseMixin, TestCase): + def setUp(self): + super(GroupTestCase, self).setUp() + self.attrs = { + "GROUP_NAME": "MYSQL", + "DESCRIPTION": "MySQL Config parameters", + "PRE_CONDITION": "y", + "PRE_CONDITION_MATCH": "y", + "POST_CONDITION": False, + "POST_CONDITION_MATCH": False} + self.params = [ + {"CONF_NAME": "CONFIG_MYSQL_HOST", "PROMPT": "find_me"}, + {"CONF_NAME": "CONFIG_MYSQL_USER"}, + {"CONF_NAME": "CONFIG_MYSQL_PW"}] + + def test_group_init(self): + """Test packstack.installer.setup_params.Group initialization""" + group = Group(attributes=self.attrs, parameters=self.params) + for key, value in self.attrs.iteritems(): + self.assertEqual(getattr(group, key), value) + for param in self.params: + self.assertIn(param['CONF_NAME'], group.parameters) + + def test_search(self): + """Test packstack.installer.setup_params.Group search method""" + group = Group(attributes=self.attrs, parameters=self.params) + param_list = group.search('PROMPT', 'find_me') + self.assertEqual(len(param_list), 1) + self.assertIsInstance(param_list[0], Parameter) + self.assertEqual(param_list[0].CONF_NAME, 'CONFIG_MYSQL_HOST') diff --git a/tests/test_base.py b/tests/test_base.py index 8b546c104..753b6b6de 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -86,3 +86,8 @@ class PackstackTestCaseMixin(object): _msg = msg or ('%s is not a member of %s' % (first, second)) if first not in second: raise AssertionError(_msg) + + def assertIsNone(self, expr, msg=None): + _msg = msg or ('%s is not None' % expr) + if expr is not None: + raise AssertionError(_msg)