406 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			406 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import sys
 | 
						|
 | 
						|
from ostack_validator.common import Inspection, Issue, MarkedIssue, Mark, Version, find, index
 | 
						|
 | 
						|
class SchemaUpdateRecord(object):
 | 
						|
  # checkpoint's data is version number
 | 
						|
  def __init__(self, version, operation, data=None):
 | 
						|
    super(SchemaUpdateRecord, self).__init__()
 | 
						|
    if not operation in ['checkpoint', 'add', 'remove']:
 | 
						|
      raise Error, 'Unknown operation "%s"' % operation
 | 
						|
    version = Version(version)
 | 
						|
    self.version = version
 | 
						|
    self.operation = operation
 | 
						|
    self.data = data
 | 
						|
 | 
						|
  def __repr__(self):
 | 
						|
    return '<SchemaUpdateRecord %s %s %s' % (self.version, self.operation, self.data)
 | 
						|
 | 
						|
class SchemaBuilder(object):
 | 
						|
  def __init__(self, name, data):
 | 
						|
    super(SchemaBuilder, self).__init__()
 | 
						|
    self.name = name
 | 
						|
    self.data = data
 | 
						|
 | 
						|
    self.current_version = None
 | 
						|
    self.current_section = None
 | 
						|
    self.adds = []
 | 
						|
    self.removals = []
 | 
						|
 | 
						|
  def __del__(self):
 | 
						|
    if len(self.adds) > 0 or len(self.removals) > 0:
 | 
						|
      sys.stderr.write("WARNING: Uncommitted config schema \"%s\" version %s\n" % (self.name, self.current_version))
 | 
						|
 | 
						|
  def version(self, version, checkpoint=False):
 | 
						|
    version = Version(version)
 | 
						|
 | 
						|
    if self.current_version and self.current_version != version:
 | 
						|
      self.commit()
 | 
						|
 | 
						|
    if checkpoint or self.data == []:
 | 
						|
      self.data.append(SchemaUpdateRecord(version, 'checkpoint'))
 | 
						|
 | 
						|
    self.current_version = version
 | 
						|
 | 
						|
  def section(self, name):
 | 
						|
    self.current_section = name
 | 
						|
 | 
						|
  def param(self, *args, **kwargs):
 | 
						|
    self._ensure_version()
 | 
						|
 | 
						|
    if not 'section' in kwargs and self.current_section:
 | 
						|
      kwargs['section'] = self.current_section
 | 
						|
 | 
						|
    self.adds.append(ConfigParameterSchema(*args, **kwargs))
 | 
						|
 | 
						|
  def remove_param(self, name):
 | 
						|
    self._ensure_version()
 | 
						|
 | 
						|
    self.removals.append(name)
 | 
						|
 | 
						|
  def commit(self):
 | 
						|
    "Finalize schema building"
 | 
						|
    self._ensure_version()
 | 
						|
 | 
						|
    if len(self.removals) > 0:
 | 
						|
      self.data.append(SchemaUpdateRecord(self.current_version, 'remove', self.removals))
 | 
						|
      self.removals = []
 | 
						|
    if len(self.adds) > 0:
 | 
						|
      self.data.append(SchemaUpdateRecord(self.current_version, 'add', self.adds))
 | 
						|
      self.adds = []
 | 
						|
 | 
						|
  def _ensure_version(self):
 | 
						|
    if not self.current_version:
 | 
						|
      raise Error, 'Schema version is not specified. Please call version() method first'
 | 
						|
 | 
						|
class ConfigSchemaRegistry:
 | 
						|
  __schemas = {}
 | 
						|
  @classmethod
 | 
						|
  def register_schema(self, project, configname=None):
 | 
						|
    if not configname:
 | 
						|
      configname = '%s.conf' % project
 | 
						|
    fullname = '%s/%s' % (project, configname)
 | 
						|
    self.__schemas[fullname] = []
 | 
						|
    return SchemaBuilder(fullname, self.__schemas[fullname])
 | 
						|
 | 
						|
  @classmethod
 | 
						|
  def get_schema(self, project, version, configname=None):
 | 
						|
    if not configname:
 | 
						|
      configname = '%s.conf' % project
 | 
						|
    fullname = '%s/%s' % (project, configname)
 | 
						|
    version = Version(version)
 | 
						|
 | 
						|
    if not fullname in self.__schemas:
 | 
						|
      return None
 | 
						|
 | 
						|
    records = self.__schemas[fullname]
 | 
						|
    i = len(records)-1
 | 
						|
    # Find latest checkpoint prior given version
 | 
						|
    while i>=0 and not (records[i].operation == 'checkpoint' and records[i].version <= version): i-=1
 | 
						|
 | 
						|
    if i < 0:
 | 
						|
      return None
 | 
						|
 | 
						|
    parameters = []
 | 
						|
    seen_parameters = set()
 | 
						|
    last_version = None
 | 
						|
 | 
						|
    while i < len(records) and records[i].version <= version:
 | 
						|
      last_version = records[i].version
 | 
						|
      if records[i].operation == 'add':
 | 
						|
        for param in records[i].data:
 | 
						|
          if param.name in seen_parameters:
 | 
						|
            old_param_index = index(parameters, lambda p: p.name == param.name)
 | 
						|
            if old_param_index != -1:
 | 
						|
              parameters[old_param_index] = param
 | 
						|
          else:
 | 
						|
            parameters.append(param)
 | 
						|
            seen_parameters.add(param.name)
 | 
						|
      elif records[i].operation == 'remove':
 | 
						|
        for param_name in records[i].data:
 | 
						|
          param_index = index(parameters, lambda p: p.name == param_name)
 | 
						|
          if index != -1:
 | 
						|
            parameters.pop(param_index)
 | 
						|
            seen_parameters.remove(param_name)
 | 
						|
      i += 1
 | 
						|
 | 
						|
    return ConfigSchema(fullname, last_version, 'ini', parameters)
 | 
						|
 | 
						|
 | 
						|
class ConfigSchema:
 | 
						|
  def __init__(self, name, version, format, parameters):
 | 
						|
    self.name = name
 | 
						|
    self.version = Version(version)
 | 
						|
    self.format = format
 | 
						|
    self.parameters = parameters
 | 
						|
 | 
						|
  def get_parameter(self, name, section=None):
 | 
						|
    # TODO: optimize this
 | 
						|
    return find(self.parameters, lambda p: p.name == name and p.section == section)
 | 
						|
 | 
						|
  def __repr__(self):
 | 
						|
    return '<ConfigSchema name=%s version=%s format=%s parameters=%s>' % (self.name, self.version, self.format, self.parameters)
 | 
						|
    
 | 
						|
class ConfigParameterSchema:
 | 
						|
  def __init__(self, name, type, section=None, description=None, default=None, required=False):
 | 
						|
    self.section = section
 | 
						|
    self.name = name
 | 
						|
    self.type = type
 | 
						|
    self.description = description
 | 
						|
    self.default = default
 | 
						|
    self.required = required
 | 
						|
 | 
						|
  def __repr__(self):
 | 
						|
    return '<ConfigParameterSchema %s>' % ' '.join(['%s=%s' % (attr, getattr(self, attr)) for attr in ['section', 'name', 'type', 'description', 'default', 'required']])
 | 
						|
 | 
						|
 | 
						|
class TypeValidatorRegistry:
 | 
						|
  __validators = {}
 | 
						|
  @classmethod
 | 
						|
  def register_validator(self, type_name, type_validator):
 | 
						|
    self.__validators[type_name] = type_validator
 | 
						|
 | 
						|
  @classmethod
 | 
						|
  def get_validator(self, name):
 | 
						|
    return self.__validators[name]
 | 
						|
 | 
						|
 | 
						|
class SchemaError(Issue):
 | 
						|
  def __init__(self, message):
 | 
						|
    super(SchemaError, self).__init__(Issue.ERROR, message)
 | 
						|
 | 
						|
class InvalidValueError(MarkedIssue):
 | 
						|
  def __init__(self, message, mark=Mark('', 0, 0)):
 | 
						|
    super(InvalidValueError, self).__init__(Issue.ERROR, 'Invalid value: '+message, mark)
 | 
						|
 | 
						|
class TypeValidator(object):
 | 
						|
  def __init__(self, f):
 | 
						|
    super(TypeValidator, self).__init__()
 | 
						|
    self.f = f
 | 
						|
 | 
						|
  def validate(self, value):
 | 
						|
    return getattr(self, 'f')(value)
 | 
						|
 | 
						|
 | 
						|
def type_validator(name, **kwargs):
 | 
						|
  def wrap(fn):
 | 
						|
    def wrapped(s):
 | 
						|
      return fn(s, **kwargs)
 | 
						|
    o = TypeValidator(wrapped)
 | 
						|
    TypeValidatorRegistry.register_validator(name, o)
 | 
						|
    return fn
 | 
						|
  return wrap
 | 
						|
 | 
						|
def isissue(o):
 | 
						|
  return isinstance(o, Issue)
 | 
						|
 | 
						|
@type_validator('boolean')
 | 
						|
def validate_boolean(s):
 | 
						|
  s = s.lower()
 | 
						|
  if s == 'true':
 | 
						|
    return True
 | 
						|
  elif s == 'false':
 | 
						|
    return False
 | 
						|
  else:
 | 
						|
    return InvalidValueError('Value should be "true" or "false"')
 | 
						|
 | 
						|
def validate_enum(s, values=[]):
 | 
						|
  if s in values:
 | 
						|
    return None
 | 
						|
  if len(values) == 0:
 | 
						|
    message = 'There should be no value'
 | 
						|
  elif len(values) == 1:
 | 
						|
    message = 'The only valid value is %s' % values[0]
 | 
						|
  else:
 | 
						|
    message = 'Valid values are %s and %s' % (', '.join(values[:-1]), values[-1])
 | 
						|
  return InvalidValueError('%s' % message)
 | 
						|
 | 
						|
def validate_ipv4_address(s):
 | 
						|
  s = s.strip()
 | 
						|
  parts = s.split('.')
 | 
						|
  if len(parts) == 4:
 | 
						|
    if all([all([c.isdigit() for c in part]) for part in parts]):
 | 
						|
      parts = [int(part) for part in parts]
 | 
						|
      if all([part < 256 for part in parts]):
 | 
						|
        return '.'.join([str(part) for part in parts])
 | 
						|
 | 
						|
  return InvalidValueError('Value should be ipv4 address')
 | 
						|
 | 
						|
def validate_ipv4_network(s):
 | 
						|
  s = s.strip()
 | 
						|
  parts = s.split('/')
 | 
						|
  if len(parts) != 2:
 | 
						|
    return InvalidValueError('Should have "/" character separating address and prefix length')
 | 
						|
 | 
						|
  address, prefix = parts
 | 
						|
  prefix = prefix.strip()
 | 
						|
 | 
						|
  if prefix.strip() == '':
 | 
						|
    return InvalidValueError('Prefix length is required')
 | 
						|
 | 
						|
  address = validate_ipv4_address(address)
 | 
						|
  if isissue(address):
 | 
						|
    return address
 | 
						|
 | 
						|
  if not all([c.isdigit() for c in prefix]):
 | 
						|
    return InvalidValueError('Prefix length should be an integer')
 | 
						|
 | 
						|
  prefix = int(prefix)
 | 
						|
  if prefix > 32:
 | 
						|
    return InvalidValueError('Prefix length should be less than or equal to 32')
 | 
						|
 | 
						|
  return '%s/%d' % (address, prefix)
 | 
						|
 | 
						|
def validate_host_label(s):
 | 
						|
  if len(s) == 0:
 | 
						|
    return InvalidValueError('Host label should have at least one character')
 | 
						|
 | 
						|
  if not s[0].isalpha():
 | 
						|
    return InvalidValueError('Host label should start with a letter, but it starts with "%s"' % s[0])
 | 
						|
 | 
						|
  if len(s) == 1: return s
 | 
						|
 | 
						|
  if not (s[-1].isdigit() or s[-1].isalpha()):
 | 
						|
    return InvalidValueError('Host label should end with letter or digit, but it ends with "%s"' % s[-1], Mark('', 0, len(s)-1))
 | 
						|
 | 
						|
  if len(s) == 2: return s
 | 
						|
 | 
						|
  for i, c in enumerate(s[1:-1]):
 | 
						|
    if not (c.isalpha() or c.isdigit() or c == '-'):
 | 
						|
      return InvalidValueError('Host label should contain only letters, digits or hypens, but it contains "%s"' % c, Mark('', 0, i+1))
 | 
						|
 | 
						|
  return s
 | 
						|
 | 
						|
 | 
						|
 | 
						|
@type_validator('host')
 | 
						|
@type_validator('host_address')
 | 
						|
def validate_host_address(s):
 | 
						|
  result = validate_ipv4_address(s)
 | 
						|
  if not isissue(result):
 | 
						|
    return result
 | 
						|
 | 
						|
  offset = len(s) - len(s.lstrip())
 | 
						|
 | 
						|
  parts = s.strip().split('.')
 | 
						|
  part_offset = offset
 | 
						|
  labels = []
 | 
						|
  for part in parts:
 | 
						|
    host_label = validate_host_label(part)
 | 
						|
    if isissue(host_label):
 | 
						|
      return host_label.offset_by(Mark('', 0, part_offset))
 | 
						|
 | 
						|
    part_offset += len(part)+1
 | 
						|
    labels.append(host_label)
 | 
						|
 | 
						|
  return '.'.join(labels)
 | 
						|
 | 
						|
@type_validator('network')
 | 
						|
@type_validator('network_address')
 | 
						|
def validate_network_address(s):
 | 
						|
  return validate_ipv4_network(s)
 | 
						|
 | 
						|
@type_validator('host_and_port')
 | 
						|
def validate_host_and_port(s, default_port=None):
 | 
						|
  parts = s.strip().split(':', 2)
 | 
						|
 | 
						|
  host_address = validate_host_address(parts[0])
 | 
						|
  if isissue(host_address):
 | 
						|
    return host_address
 | 
						|
 | 
						|
  if len(parts) == 2:
 | 
						|
    port = validate_port(parts[1])
 | 
						|
    if isissue(port):
 | 
						|
      return port
 | 
						|
  elif default_port:
 | 
						|
    port = default_port
 | 
						|
  else:
 | 
						|
    return InvalidValueError('No port specified')
 | 
						|
 | 
						|
  return (host_address, port)
 | 
						|
 | 
						|
@type_validator('string')
 | 
						|
def validate_string(s):
 | 
						|
  return s
 | 
						|
 | 
						|
@type_validator('integer')
 | 
						|
def validate_integer(s, min=None, max=None):
 | 
						|
  leading_whitespace_len = 0
 | 
						|
  while leading_whitespace_len < len(s) and s[leading_whitespace_len].isspace(): leading_whitespace_len += 1
 | 
						|
 | 
						|
  s = s.strip()
 | 
						|
  if s == '':
 | 
						|
    return InvalidValueError('Should not be empty')
 | 
						|
 | 
						|
  for i, c in enumerate(s):
 | 
						|
    if not c.isdigit() and not ((c == '-') and (i == 0)):
 | 
						|
      return InvalidValueError('Only digits are allowed, but found char "%s"' % c, Mark('', 1, i+1+leading_whitespace_len))
 | 
						|
 | 
						|
  v = int(s)
 | 
						|
  if min and v < min:
 | 
						|
    return InvalidValueError('Should be greater than or equal to %d' % min, Mark('', 1, leading_whitespace_len))
 | 
						|
  if max and v > max:
 | 
						|
    return InvalidValueError('Should be less than or equal to %d' % max, Mark('', 1, leading_whitespace_len))
 | 
						|
 | 
						|
  return v
 | 
						|
 | 
						|
@type_validator('port')
 | 
						|
def validate_port(s, min=1, max=65535):
 | 
						|
  return validate_integer(s, min=min, max=max)
 | 
						|
 | 
						|
@type_validator('string_list')
 | 
						|
def validate_list(s, element_type='string'):
 | 
						|
  element_type_validator = TypeValidatorRegistry.get_validator(element_type)
 | 
						|
  if not element_type_validator:
 | 
						|
    return SchemaError('Invalid element type "%s"' % element_type)
 | 
						|
 | 
						|
  result = []
 | 
						|
  s = s.strip()
 | 
						|
 | 
						|
  if s == '':
 | 
						|
    return result
 | 
						|
 | 
						|
  values = s.split(',')
 | 
						|
  for value in values:
 | 
						|
    validated_value = element_type_validator.validate(value.strip())
 | 
						|
    if isinstance(validated_value, Issue):
 | 
						|
      # TODO: provide better position reporting
 | 
						|
      return validated_value
 | 
						|
    result.append(validated_value)
 | 
						|
 | 
						|
  return result
 | 
						|
 | 
						|
@type_validator('string_dict')
 | 
						|
def validate_dict(s, element_type='string'):
 | 
						|
  element_type_validator = TypeValidatorRegistry.get_validator(element_type)
 | 
						|
  if not element_type_validator:
 | 
						|
    return SchemaError('Invalid element type "%s"' % element_type)
 | 
						|
 | 
						|
  result = {}
 | 
						|
  s = s.strip()
 | 
						|
 | 
						|
  if s == '':
 | 
						|
    return result
 | 
						|
 | 
						|
  pairs = s.split(',')
 | 
						|
  for pair in pairs:
 | 
						|
    key_value = pair.split(':', 2)
 | 
						|
    if len(key_value) < 2:
 | 
						|
      return InvalidValueError('Value should be NAME:VALUE pairs separated by ","')
 | 
						|
 | 
						|
    key, value = key_value
 | 
						|
    key = key.strip()
 | 
						|
    value = value.strip()
 | 
						|
 | 
						|
    if key == '':
 | 
						|
      # TODO: provide better position reporting
 | 
						|
      return InvalidValueError('Key name should not be empty')
 | 
						|
 | 
						|
    validated_value = element_type_validator.validate(value)
 | 
						|
    if isinstance(validated_value, Issue):
 | 
						|
      # TODO: provide better position reporting
 | 
						|
      return validated_value
 | 
						|
    result[key] = validated_value
 | 
						|
  return result
 | 
						|
 |