"""Configuration schema registry and value-type validators."""
import re
import os.path

import yaml

from rubick.common import Issue, MarkedIssue, Mark, Version, find, index
from rubick.exceptions import RubickException


class SchemaError(RubickException):
    """Exception type for schema-related failures."""
    pass


class ConfigSchemaRegistry:
    """Loads configuration schemas from YAML files stored under ``db_path``."""

    # Directory holding ``<project>/<configname>.yml`` schema history files.
    db_path = os.path.join(os.path.dirname(__file__), 'schemas')

    @classmethod
    def get_schema(cls, project, version, configname=None):
        """Build the schema of ``configname`` as of ``version``.

        The schema file is a list of records, each with a 'version' key and
        optional 'checkpoint', 'added' and 'removed' keys.  Records are
        replayed starting from the latest checkpoint that is not newer than
        ``version`` up to (and including) the last record not newer than
        ``version``.

        :param project: project name; also the sub-directory of ``db_path``
        :param version: requested version (anything ``Version`` accepts)
        :param configname: config file name; defaults to '<project>.conf'
        :returns: a ``ConfigSchema`` or None when the schema file is missing,
            empty, or its earliest record is newer than ``version``
        """
        if not configname:
            configname = '%s.conf' % project
        fullname = '%s/%s' % (project, configname)
        version = Version(version)

        path = os.path.join(cls.db_path, project, configname + '.yml')
        if not os.path.exists(path):
            return None

        with open(path) as f:
            # safe_load is used instead of yaml.load: the records are plain
            # lists/dicts/scalars, and yaml.load without an explicit Loader
            # can construct arbitrary objects (and is rejected by PyYAML 6).
            records = yaml.safe_load(f)

        # An empty (or null) schema file has nothing to replay.
        if not records:
            return None

        i = len(records) - 1
        # Find latest checkpoint prior to the given version
        while i >= 0 and not (
                records[i].get('checkpoint', False)
                and Version(records[i]['version']) <= version):
            i -= 1

        if i < 0:
            if Version(records[0]['version']) > version:
                # Reached the earliest record yet haven't found version
                return None

            # Haven't found a checkpoint but the earliest version is less
            # than the given one. Assuming first record is checkpoint
            i = 0

        parameters = []
        seen_parameters = set()
        last_version = None

        while i < len(records) and Version(records[i]['version']) <= version:
            last_version = records[i]['version']

            for param_data in records[i].get('added', []):
                name = param_data['name']
                section = None
                # 'section.name' notation carries the section in the name.
                if '.' in name:
                    section, name = name.split('.', 1)

                param = ConfigParameterSchema(
                    name, param_data['type'], section=section,
                    default=param_data.get('default', None),
                    description=param_data.get('help', None),
                    required=param_data.get('required', False),
                    deprecation_message=param_data.get('deprecated', None))

                # NOTE(review): parameters are tracked by bare name only, so
                # same-named parameters of different sections replace each
                # other -- confirm this is intended.
                if param.name in seen_parameters:
                    old_param_index = index(
                        parameters, lambda p: p.name == param.name)
                    if old_param_index != -1:
                        parameters[old_param_index] = param
                else:
                    parameters.append(param)
                    seen_parameters.add(param.name)

            for param_name in records[i].get('removed', []):
                param_index = index(
                    parameters, lambda p: p.name == param_name)
                # Check the lookup result (not the ``index`` function
                # itself); otherwise an unknown name would pop the last
                # parameter via ``parameters.pop(-1)``.
                if param_index != -1:
                    parameters.pop(param_index)
                    seen_parameters.discard(param_name)

            i += 1

        return ConfigSchema(fullname, last_version, 'ini', parameters)


class ConfigSchema:
    """Schema of one configuration file: a named, versioned parameter list."""

    def __init__(self, name, version, format, parameters):
        self.name = name
        self.version = Version(version)
        self.format = format
        self.parameters = parameters

    def has_section(self, section):
        """Return True when some parameter belongs to ``section``."""
        return (
            find(self.parameters, lambda p: p.section == section) is not None
        )

    def get_parameter(self, name, section=None):
        """Return the result of ``find`` for ``name`` within ``section``.

        Presumably None when there is no match (``has_section`` relies on
        ``find`` returning None in that case).
        """
        # TODO: optimize this
        return (
            find(
                self.parameters,
                lambda p: p.name == name and p.section == section)
        )

    def __repr__(self):
        return (
            '<ConfigSchema name=%s version=%s format=%s parameters=%s>' % (
                self.name, self.version, self.format, self.parameters)
        )


class ConfigParameterSchema:
    """Description of a single configuration parameter."""

    def __init__(self, name, type, section=None, description=None,
                 default=None, required=False, deprecation_message=None):
        self.section = section
        self.name = name
        self.type = type
        self.description = description
        self.default = default
        self.required = required
        self.deprecation_message = deprecation_message

    def __repr__(self):
        return (
            '<ConfigParameterSchema %s>' % ' '.join(
                ['%s=%s' % (attr, getattr(self, attr))
                 for attr in ['section', 'name', 'type', 'description',
                              'default', 'required']])
        )


class TypeValidatorRegistry:
    """Class-level registry mapping type names to ``TypeValidator`` objects."""

    __validators = {}
    __default_validator = None

    @classmethod
    def register_validator(cls, type_name, type_validator, default=False):
        """Register ``type_validator`` under ``type_name``.

        When ``default`` is true, this type also becomes the fallback used
        by ``get_validator`` for unknown names.
        """
        cls.__validators[type_name] = type_validator
        if default:
            cls.__default_validator = type_name

    @classmethod
    def get_validator(cls, name):
        """Return the validator for ``name`` or the default validator.

        Returns None when ``name`` is unknown and no default is registered.
        """
        validator = cls.__validators.get(name)
        if validator is None:
            # Resolve the fallback lazily so that a missing default does not
            # break lookups of names that are registered.
            validator = cls.__validators.get(cls.__default_validator)
        return validator


class SchemaIssue(Issue):
    """Error-level issue describing a problem with the schema itself."""

    def __init__(self, message):
        super(SchemaIssue, self).__init__(Issue.ERROR, message)


class InvalidValueError(MarkedIssue):
    """Error-level issue for an invalid value, with a position mark."""

    def __init__(self, message, mark=Mark('', 0, 0)):
        super(InvalidValueError, self).__init__(
            Issue.ERROR, 'Invalid value: ' + message, mark)


class TypeValidator(object):
    """Wraps a validation callable together with its base type name."""

    def __init__(self, base_type, f):
        super(TypeValidator, self).__init__()
        self.base_type = base_type
        self.f = f

    def validate(self, value):
        """Run the wrapped callable on ``value``; None passes through."""
        if value is None:
            return value
        return self.f(value)


def type_validator(name, base_type=None, default=False, **kwargs):
    """Decorator registering the function as validator for type ``name``.

    Extra keyword arguments are passed to the function on every call.  The
    decorated function itself is returned unchanged, so decorators can be
    stacked to register one function under several type names.
    """
    if not base_type:
        base_type = name

    def wrap(fn):
        def wrapped(s):
            return fn(s, **kwargs)
        o = TypeValidator(base_type, wrapped)
        TypeValidatorRegistry.register_validator(name, o, default=default)
        return fn

    return wrap


def isissue(o):
    """Return True when ``o`` is an ``Issue`` instance."""
    return isinstance(o, Issue)


_ASCII_DIGITS = frozenset('0123456789')


def _is_ascii_digits(s):
    """Return True when ``s`` is non-empty and has only the digits 0-9.

    Stricter than ``str.isdigit`` on purpose: everything accepted here is
    guaranteed to be parseable by ``int``.
    """
    return len(s) > 0 and all(c in _ASCII_DIGITS for c in s)


@type_validator('boolean')
def validate_boolean(s):
    """Parse 'true'/'false' (case-insensitive) into a bool."""
    if isinstance(s, bool):
        return s

    s = s.lower()
    if s == 'true':
        return True
    elif s == 'false':
        return False
    else:
        return InvalidValueError('Value should be "true" or "false"')


def validate_enum(s, values=()):
    """Return None when ``s`` is one of ``values``, else an issue."""
    if s in values:
        return None
    if len(values) == 0:
        message = 'There should be no value'
    elif len(values) == 1:
        message = 'The only valid value is %s' % values[0]
    else:
        message = 'Valid values are %s and %s' % (
            ', '.join(values[:-1]), values[-1])
    return InvalidValueError('%s' % message)


def validate_ipv4_address(s):
    """Validate a dotted-quad IPv4 address and return it normalized.

    Normalization strips surrounding whitespace and leading zeros of each
    octet.  Returns an ``InvalidValueError`` for anything else.
    """
    s = s.strip()
    parts = s.split('.')
    # Requiring non-empty digit-only octets keeps int() below from raising
    # on inputs such as '1..2.3'.
    if len(parts) == 4 and all(_is_ascii_digits(part) for part in parts):
        octets = [int(part) for part in parts]
        if all(octet < 256 for octet in octets):
            return '.'.join(str(octet) for octet in octets)

    return InvalidValueError('Value should be ipv4 address')


def validate_ipv4_network(s):
    """Validate 'ADDRESS/PREFIX' notation and return it normalized."""
    s = s.strip()
    parts = s.split('/')
    if len(parts) != 2:
        return (
            InvalidValueError(
                'Should have "/" character separating address and prefix '
                'length')
        )

    address, prefix = parts
    prefix = prefix.strip()

    if prefix == '':
        return InvalidValueError('Prefix length is required')

    address = validate_ipv4_address(address)
    if isissue(address):
        return address

    if not _is_ascii_digits(prefix):
        return InvalidValueError('Prefix length should be an integer')

    prefix = int(prefix)
    if prefix > 32:
        return (
            InvalidValueError(
                'Prefix length should be less than or equal to 32')
        )

    return '%s/%d' % (address, prefix)


def validate_host_label(s):
    """Validate a single host name label (the part between dots).

    A label starts with a letter, ends with a letter or digit, and has only
    letters, digits or '-' in between.  Returns ``s`` or an issue.
    """
    if len(s) == 0:
        return InvalidValueError(
            'Host label should have at least one character')

    if not s[0].isalpha():
        return InvalidValueError(
            'Host label should start with a letter, but it starts with '
            '"%s"' % s[0])

    if len(s) == 1:
        return s

    if not (s[-1].isdigit() or s[-1].isalpha()):
        return InvalidValueError(
            'Host label should end with letter or digit, but it ends '
            'with "%s"' % s[-1], Mark('', 0, len(s) - 1))

    if len(s) == 2:
        return s

    for i, c in enumerate(s[1:-1]):
        if not (c.isalpha() or c.isdigit() or c == '-'):
            # i indexes the inner slice, hence the +1 for the column.
            return InvalidValueError(
                'Host label should contain only letters, digits or hypens,'
                ' but it contains "%s"' % c, Mark('', 0, i + 1))

    return s


@type_validator('host', base_type='string')
@type_validator('host_address', base_type='string')
@type_validator('old_network', base_type='string')
def validate_host_address(s):
    """Validate an IPv4 address or a dot-separated host name."""
    result = validate_ipv4_address(s)
    if not isissue(result):
        return result

    offset = len(s) - len(s.lstrip())

    parts = s.strip().split('.')
    part_offset = offset
    labels = []
    for part in parts:
        host_label = validate_host_label(part)
        if isissue(host_label):
            # Shift the label-relative mark to its place in the full value.
            return host_label.offset_by(Mark('', 0, part_offset))

        part_offset += len(part) + 1
        labels.append(host_label)

    return '.'.join(labels)


@type_validator('network', base_type='string')
@type_validator('network_address', base_type='string')
def validate_network_address(s):
    """Validate an IPv4 network in 'ADDRESS/PREFIX' notation."""
    return validate_ipv4_network(s)


@type_validator('network_mask', base_type='string')
def validate_network_mask(s):
    """Validate a dotted-quad IPv4 network mask.

    Accepts a run of 255 octets, followed by at most one partial octet from
    the allowed set, followed only by zero octets.
    """
    # TODO: implement proper checking
    result = validate_ipv4_address(s)
    if isissue(result):
        return result

    parts = [int(p) for p in result.split('.', 3)]

    # Position of the first octet that is not fully set; -1 is treated as
    # "no such octet" (presumably what ``index`` returns on no match).
    x = index(parts, lambda p: p != 255)
    if x == -1:
        return result

    if parts[x] not in [0, 128, 192, 224, 240, 248, 252, 254]:
        return InvalidValueError('Invalid netmask')

    # Every octet after the first non-255 one has to be zero.
    if any(part != 0 for part in parts[x + 1:]):
        return InvalidValueError('Invalid netmask')

    return result


@type_validator('host_and_port', base_type='string')
def validate_host_and_port(s, default_port=None):
    """Validate 'HOST[:PORT]' and return a ``(host, port)`` tuple."""
    parts = s.strip().split(':', 2)

    host_address = validate_host_address(parts[0])
    if isissue(host_address):
        return host_address

    if len(parts) == 2:
        port = validate_port(parts[1])
        if isissue(port):
            return port
    elif default_port:
        port = default_port
    else:
        return InvalidValueError('No port specified')

    return (host_address, port)


@type_validator('string', base_type='string', default=True)
@type_validator('list', base_type='list')
@type_validator('multi', base_type='multi')
@type_validator('file', base_type='string')
@type_validator('directory', base_type='string')
@type_validator('regex', base_type='string')
@type_validator('host_v6', base_type='string')
def validate_string(s):
    """Accept any value unchanged."""
    return s


@type_validator('integer')
def validate_integer(s, min=None, max=None):
    """Parse a decimal integer, optionally bounded by ``min`` and ``max``.

    Values that already are ``int`` are returned as is.  Returns the parsed
    integer or an ``InvalidValueError``.
    """
    if isinstance(s, int):
        return s

    leading_whitespace_len = 0
    while leading_whitespace_len < len(s) \
            and s[leading_whitespace_len].isspace():
        leading_whitespace_len += 1

    s = s.strip()
    if s == '':
        return InvalidValueError('Should not be empty')

    for i, c in enumerate(s):
        # A leading minus is a sign only when digits follow it; a lone '-'
        # is reported as an invalid character instead of crashing int().
        is_sign = (c == '-') and (i == 0) and (len(s) > 1)
        if c not in _ASCII_DIGITS and not is_sign:
            return (
                InvalidValueError(
                    'Only digits are allowed, but found char "%s"' % c,
                    Mark('', 1, i + 1 + leading_whitespace_len))
            )

    v = int(s)
    # Compare against None explicitly so that a bound of 0 is honoured.
    if min is not None and v < min:
        return (
            InvalidValueError(
                'Should be greater than or equal to %d' % min,
                Mark('', 1, leading_whitespace_len))
        )
    if max is not None and v > max:
        return (
            InvalidValueError(
                'Should be less than or equal to %d' % max,
                Mark('', 1, leading_whitespace_len))
        )

    return v


@type_validator('file_mode')
def validate_file_mode(s):
    """Validate a file mode; currently any integer is accepted."""
    return validate_integer(s)


@type_validator('float')
def validate_float(s):
    """Parse a floating point number; returns a float or an issue."""
    if isinstance(s, float):
        return s

    # TODO: Implement proper validation
    try:
        return float(s)
    except ValueError:
        # Report malformed input as an issue, like the other validators do.
        return InvalidValueError('Should be a floating point number')


@type_validator('port', base_type='integer')
def validate_port(s, min=1, max=65535):
    """Validate a port number within ``min``..``max``."""
    return validate_integer(s, min=min, max=max)


def validate_list(s, element_type):
    """Parse a comma-separated list validating every element.

    An element that fails validation is re-joined with the following
    comma-separated piece and retried, which lets elements contain commas.

    :returns: list of validated elements, or the issue of the first element
        that cannot be validated
    """
    if isinstance(s, list):
        return s

    element_type_validator = TypeValidatorRegistry.get_validator(element_type)
    if not element_type_validator:
        return SchemaIssue('Invalid element type "%s"' % element_type)

    result = []
    s = s.strip()

    if s == '':
        return result

    values = s.split(',')
    while len(values) > 0:
        value = values.pop(0)
        while True:
            validated_value = element_type_validator.validate(value.strip())
            if not isinstance(validated_value, Issue):
                break

            if len(values) == 0:
                # TODO: provide better position reporting
                return validated_value

            # Glue the NEXT piece back on (front of the queue), keeping the
            # original order of the input.
            value += ',' + values.pop(0)

        result.append(validated_value)

    return result


@type_validator('string_list', base_type='list')
def validate_string_list(s):
    """Validate a comma-separated list of strings."""
    return validate_list(s, element_type='string')


@type_validator('string_dict', base_type='multi')
def validate_dict(s, element_type='string'):
    """Parse 'NAME:VALUE' pairs separated by ',' into a dict."""
    if isinstance(s, dict):
        return s

    element_type_validator = TypeValidatorRegistry.get_validator(element_type)
    if not element_type_validator:
        return SchemaIssue('Invalid element type "%s"' % element_type)

    result = {}
    s = s.strip()

    if s == '':
        return result

    pairs = s.split(',')
    for pair in pairs:
        # Split on the first ':' only, so the value may contain ':' and the
        # unpacking below always gets exactly two items.
        key_value = pair.split(':', 1)
        if len(key_value) < 2:
            return (
                InvalidValueError(
                    'Value should be NAME:VALUE pairs separated by ","')
            )

        key, value = key_value
        key = key.strip()
        value = value.strip()

        if key == '':
            # TODO: provide better position reporting
            return InvalidValueError('Key name should not be empty')

        validated_value = element_type_validator.validate(value)
        if isinstance(validated_value, Issue):
            # TODO: provide better position reporting
            return validated_value
        result[key] = validated_value
    return result


@type_validator('rabbitmq_bind', base_type='string')
def validate_rabbitmq_bind(s):
    """Validate a RabbitMQ bind: 'PORT' or '{"HOST", PORT}'.

    :returns: ``(host, port)`` tuple ('0.0.0.0' when only a port is given)
        or an issue
    """
    m = re.match(r'\d+', s)
    if m:
        port = validate_port(s)
        if isinstance(port, Issue):
            return port

        return ('0.0.0.0', port)

    m = re.match(r'\{\s*"(.+)"\s*,\s*(\d+)\s*\}', s)
    if m:
        host = validate_host_address(m.group(1))
        port = validate_port(m.group(2))

        if isinstance(host, Issue):
            return host

        if isinstance(port, Issue):
            return port

        return (host, port)

    return SchemaIssue("Unrecognized bind format")


def validate_rabbitmq_list(s, element_type):
    """Validate a '[a, b, ...]' list of ``element_type`` elements."""
    if isinstance(s, list):
        return s

    if not (s.startswith('[') and s.endswith(']')):
        return SchemaIssue('List should be surrounded by [ and ]')

    return validate_list(s[1:-1], element_type=element_type)


@type_validator('rabbitmq_bind_list', base_type='list')
def validate_rabbitmq_bind_list(s):
    """Validate a bracketed list of RabbitMQ binds."""
    return validate_rabbitmq_list(s, element_type='rabbitmq_bind')