458 lines
15 KiB
Python
458 lines
15 KiB
Python
# Copyright 2016 Huawei Technologies Co.,LTD.
|
|
# All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
"""Utilities and helper functions."""
|
|
|
|
import base64
|
|
import binascii
|
|
import contextlib
|
|
import eventlet
|
|
import functools
|
|
import inspect
|
|
import os
|
|
import re
|
|
import shutil
|
|
import tempfile
|
|
import traceback
|
|
|
|
from cryptography.hazmat import backends
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives import serialization
|
|
from cryptography import x509
|
|
from oslo_concurrency import lockutils
|
|
from oslo_concurrency import processutils
|
|
from oslo_context import context as common_context
|
|
from oslo_log import log as logging
|
|
from oslo_utils import encodeutils
|
|
from oslo_utils import importutils
|
|
import paramiko
|
|
import six
|
|
|
|
from mogan.common import exception
|
|
from mogan.common.i18n import _
|
|
from mogan.common import states
|
|
from mogan.conf import CONF
|
|
from mogan import objects
|
|
|
|
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# Lock decorator obtained from oslo.concurrency with the 'mogan-' prefix.
synchronized = lockutils.synchronized_with_prefix('mogan-')

# Result of attempting to import osprofiler.profiler. Code in this module
# checks it for truthiness before use, so it may be unavailable.
profiler = importutils.try_import('osprofiler.profiler')
|
|
|
|
|
|
def safe_rstrip(value, chars=None):
    """Strip trailing characters unless that would leave an empty string.

    :param value: A string value that will be stripped.
    :param chars: Characters to remove (``None`` strips whitespace).
    :return: The stripped value; the original value when stripping would
             empty it; or ``value`` unchanged (with a warning logged) when
             it is not a string.

    """
    if isinstance(value, six.string_types):
        stripped = value.rstrip(chars)
        # An empty result means everything was stripped: keep the input.
        return stripped if stripped else value

    LOG.warning("Failed to remove trailing character. Returning "
                "original object. Supplied object is not a string: "
                "%s,", value)
    return value
|
|
|
|
|
|
def is_valid_mac(address):
    """Verify the format of a MAC address.

    Check if a MAC address is valid and contains six octets. Accepts
    colon-separated format only. Matching is case-insensitive.

    :param address: MAC address to be validated.
    :returns: True if valid. False if not (including when ``address`` is
              not a string).

    """
    if not isinstance(address, six.string_types):
        return False
    # "\Z" anchors at the absolute end of the string. A plain "$" would also
    # match just before a trailing newline and let "aa:...:ff\n" through.
    pattern = r"[0-9a-f]{2}(:[0-9a-f]{2}){5}\Z"
    # Return a real boolean, as documented, rather than a match object.
    return re.match(pattern, address.lower()) is not None
|
|
|
|
|
|
def validate_and_normalize_mac(address):
    """Validate a MAC address and return it in normalized form.

    The address is checked with is_valid_mac() and, when acceptable,
    converted to all lower case.

    :param address: MAC address to be validated and normalized.
    :returns: Normalized and validated MAC address.
    :raises InvalidMAC: If the MAC address is not valid.

    """
    if is_valid_mac(address):
        return address.lower()
    raise exception.InvalidMAC(mac=address)
|
|
|
|
|
|
def make_pretty_name(method):
    """Build a readable dotted name for a function or bound method.

    :param method: A callable exposing ``__name__``.
    :returns: ``"ClassName.name"`` when the callable has a non-None
              ``__self__`` (the class name is taken from that object),
              otherwise just ``"name"``.
    """
    pieces = [method.__name__]
    owner = getattr(method, '__self__', None)
    if owner is not None:
        # Bound callable: prefix the class name of the object it is bound to.
        try:
            pieces = [owner.__class__.__name__] + pieces
        except AttributeError:
            pass
    return ".".join(pieces)
|
|
|
|
|
|
def check_isinstance(obj, cls):
    """Return ``obj`` if it is an instance of ``cls``, else raise.

    Returning the object (rather than a boolean) lets PyLint infer its type.

    :param obj: Object to check.
    :param cls: Expected type (anything accepted by isinstance()).
    :returns: ``obj`` itself.
    :raises Exception: If ``obj`` is not an instance of ``cls``.
    """
    if not isinstance(obj, cls):
        raise Exception(_('Expected object of type: %s') % (str(cls)))
    return obj
|
|
|
|
|
|
def get_state_machine(start_state=None, target_state=None):
    """Return an initialized copy of ``states.machine``.

    :param start_state: Passed to the machine's initialize() call.
    :param target_state: Passed to the machine's initialize() call.
    :returns: The initialized copy.
    """
    # Work on a copy rather than on the module-level machine itself.
    machine = states.machine.copy()
    machine.initialize(start_state=start_state, target_state=target_state)
    return machine
|
|
|
|
|
|
def process_event(fsm, server, event=None):
    """Feed ``event`` to ``fsm`` and persist the resulting state.

    :param fsm: State machine exposing process_event() and current_state.
    :param server: Object whose ``status`` is updated and then saved.
    :param event: Event handed to the state machine.
    """
    fsm.process_event(event)
    # Mirror the machine's state after the event onto the server, then save.
    new_status = fsm.current_state
    server.status = new_status
    server.save()
|
|
|
|
|
|
def get_wrapped_function(function):
    """Get the method at the bottom of a stack of decorators.

    Decorators are followed through closure cells: the first cell holding a
    callable is descended into, recursively, until a callable without a
    closure (or without any callable in its closure) is reached.

    :param function: The possibly decorated callable.
    :returns: The innermost callable found. ``function`` itself is returned
              when it has no closure or when its closure holds no callable
              (e.g. it only closes over plain values).
    """
    if not getattr(function, '__closure__', None):
        return function

    # ids of functions already being inspected; guards against closures that
    # refer back to themselves (e.g. a nested recursive function).
    seen = set()

    def _get_wrapped_function(function):
        if not getattr(function, '__closure__', None):
            return None
        if id(function) in seen:
            return None
        seen.add(id(function))

        for closure in function.__closure__:
            try:
                func = closure.cell_contents
            except ValueError:
                # Empty cell: the free variable is not bound, skip it.
                continue

            deeper_func = _get_wrapped_function(func)
            if deeper_func:
                return deeper_func
            elif hasattr(func, '__call__'):
                return func

        # No callable in the closure: this function wraps nothing, so it is
        # itself the innermost one (previously None leaked out here).
        return function

    return _get_wrapped_function(function)
|
|
|
|
|
|
def expects_func_args(*args):
    """Guard a decorator so it is only applied to suitable functions.

    The returned callable takes a decorator ``dec`` and returns a replacement
    that, before applying ``dec`` to a function, checks that the innermost
    wrapped function (see get_wrapped_function) accepts every name in
    ``args``. Functions taking ``*args`` or ``**kwargs`` always pass.

    :param args: Argument names the decorated function must declare.
    :returns: A function mapping a decorator to its guarded version.
    :raises TypeError: Raised by the guarded decorator when the decorated
                       function lacks one of the expected arguments.
    """
    def _decorator_checker(dec):
        @functools.wraps(dec)
        def _decorator(f):
            base_f = get_wrapped_function(f)
            # inspect.getargspec() was removed in Python 3.11 and rejects
            # keyword-only/annotated signatures; prefer getfullargspec() and
            # fall back only on interpreters that lack it.
            if hasattr(inspect, 'getfullargspec'):
                spec = inspect.getfullargspec(base_f)
                arg_names = list(spec.args) + list(spec.kwonlyargs)
                accepts_any = spec.varargs or spec.varkw
            else:
                legacy = inspect.getargspec(base_f)
                arg_names = legacy[0]
                accepts_any = legacy[1] or legacy[2]
            if accepts_any or set(args) <= set(arg_names):
                return dec(f)
            raise TypeError("Decorated function %(f_name)s does not "
                            "have the arguments expected by the "
                            "decorator %(d_name)s" %
                            {'f_name': base_f.__name__,
                             'd_name': dec.__name__})
        return _decorator
    return _decorator_checker
|
|
|
|
|
|
def _get_fault_detail(exc_info, error_code):
    """Return traceback text for a fault, or an empty string.

    A formatted traceback is produced only when ``exc_info`` is provided,
    ``error_code`` equals 500 and ``exc_info[2]`` (the traceback) is set.

    :param exc_info: Exception info sequence; index 2 is the traceback.
    :param error_code: Numeric code of the fault.
    :returns: Text (six.text_type) with the formatted traceback, or ''.
    """
    text = ''
    wants_traceback = exc_info and error_code == 500
    if wants_traceback and exc_info[2]:
        text = ''.join(traceback.format_tb(exc_info[2]))
    return six.text_type(text)
|
|
|
|
|
|
def safe_truncate(value, length):
    """Truncate a unicode string so its encoded form fits in ``length``.

    :param value: The string to truncate.
    :param length: Maximum size of the encoded value.
    :returns: The decoded string whose encoding is no longer than
              ``length``.
    """
    encoded = encodeutils.safe_encode(value)[:length]

    # Slicing the encoded bytes may cut a multi-byte character in half,
    # which makes decoding fail. Drop one trailing byte at a time until the
    # remainder decodes cleanly.
    while True:
        try:
            return encodeutils.safe_decode(encoded)
        except UnicodeDecodeError:
            encoded = encoded[:-1]
|
|
|
|
|
|
def add_server_fault_from_exc(context, server, fault, exc_info=None,
                              fault_message=None):
    """Adds the specified fault to the database.

    :param context: Request context handed to the ServerFault object.
    :param server: Server whose ``uuid`` the fault is recorded against.
    :param fault: The exception describing the fault. Its ``kwargs['code']``
                  is used as the fault code when present, else 500.
    :param exc_info: Optional exception info used for the fault detail.
    :param fault_message: Optional message overriding the one derived from
                          ``fault``.
    """
    code = fault.kwargs.get('code', 500) if hasattr(fault, "kwargs") else 500

    # Derive a message: explicit message, then fault.format_message(), then
    # the text form of the fault, and finally the exception class name. The
    # handlers are broad so a failure to build the message never prevents
    # the fault from being recorded.
    try:
        if not fault_message:
            fault_message = fault.format_message()
    except Exception:
        try:
            fault_message = six.text_type(fault)
        except Exception:
            fault_message = None
    if not fault_message:
        fault_message = fault.__class__.__name__

    fault_dict = {
        "exception": fault,
        "message": safe_truncate(fault_message, 255),
        "code": code,
    }

    fault_obj = objects.ServerFault(context=context)
    fault_obj.server_uuid = server.uuid
    fault_obj.update(fault_dict)
    # Use the code as stored on the object when building the detail.
    fault_obj.detail = _get_fault_detail(exc_info, fault_obj.code)
    fault_obj.create()
|
|
|
|
|
|
def execute(*cmd, **kwargs):
    """Convenience wrapper around oslo's execute() method.

    :param cmd: Passed to processutils.execute.
    :param use_standard_locale: True | False. Defaults to False. If set to
                                True, execute command with standard locale
                                added to environment variables.
    :param kwargs: Remaining keyword arguments are passed unchanged to
                   processutils.execute.
    :returns: (stdout, stderr) from process execution
    :raises: UnknownArgumentError
    :raises: ProcessExecutionError
    """

    use_standard_locale = kwargs.pop('use_standard_locale', False)
    if use_standard_locale:
        env = kwargs.pop('env_variables', None)
        # Always work on a private copy so that the LC_ALL override never
        # leaks into a mapping owned by the caller.
        env = os.environ.copy() if env is None else dict(env)
        env['LC_ALL'] = 'C'
        kwargs['env_variables'] = env
    result = processutils.execute(*cmd, **kwargs)
    LOG.debug('Execution completed, command line is "%s"',
              ' '.join(map(str, cmd)))
    LOG.debug('Command stdout is: "%s"', result[0])
    LOG.debug('Command stderr is: "%s"', result[1])
    return result
|
|
|
|
|
|
@contextlib.contextmanager
def tempdir(**kwargs):
    """Context manager yielding a temporary directory, removed on exit.

    ``tempfile.tempdir`` is set to ``CONF.tempdir`` before the directory is
    created. A failure to remove the directory is logged, not raised.

    :param kwargs: Passed to tempfile.mkdtemp.
    """
    tempfile.tempdir = CONF.tempdir
    path = tempfile.mkdtemp(**kwargs)
    try:
        yield path
    finally:
        try:
            shutil.rmtree(path)
        except OSError as err:
            LOG.error('Could not remove tmpdir: %s', err)
|
|
|
|
|
|
def mkfs(fs, path, label=None, run_as_root=False):
    """Format a file or block device

    :param fs: Filesystem type (examples include 'swap', 'ext3', 'ext4'
               'btrfs', etc.)
    :param path: Path to file or block device to format
    :param label: Volume label to use
    :param run_as_root: Passed through to execute().
    """
    command = ['mkswap'] if fs == 'swap' else ['mkfs', '-t', fs]

    # add -F to force no interactive execute on non-block device.
    if fs in ('ext3', 'ext4', 'ntfs'):
        command.append('-F')

    if label:
        # FAT-style filesystems take the label via -n, the others via -L.
        label_flag = '-n' if fs in ('msdos', 'vfat') else '-L'
        command += [label_flag, label]

    command.append(path)
    execute(*command, run_as_root=run_as_root)
|
|
|
|
|
|
def trycmd(*args, **kwargs):
    """Convenience wrapper around oslo's trycmd() method.

    All positional and keyword arguments are forwarded unchanged.

    :returns: Whatever processutils.trycmd returns.
    """
    outcome = processutils.trycmd(*args, **kwargs)
    return outcome
|
|
|
|
|
|
def check_string_length(value, name=None, min_length=0, max_length=None):
    """Check the length of specified string

    :param value: the value of the string
    :param name: the name of the string (defaults to the value itself in
                 length-related error messages)
    :param min_length: the min_length of the string
    :param max_length: the max_length of the string; a falsy value disables
                       the upper bound check
    :raises Invalid: if ``value`` is not a string or violates a bound
    """
    if not isinstance(value, six.string_types):
        msg = ("The input is not a string or unicode" if name is None
               else "%s is not a string or unicode" % name)
        raise exception.Invalid(message=msg)

    if name is None:
        name = value
    size = len(value)

    if size < min_length:
        msg = _("%(name)s has a minimum character requirement of "
                "%(min_length)s.") % {'name': name, 'min_length': min_length}
        raise exception.Invalid(message=msg)

    if max_length and size > max_length:
        msg = _("%(name)s has more than %(max_length)s "
                "characters.") % {'name': name, 'max_length': max_length}
        raise exception.Invalid(message=msg)
|
|
|
|
|
|
def _create_x509_openssl_config(conffile, upn):
|
|
content = ("distinguished_name = req_distinguished_name\n"
|
|
"[req_distinguished_name]\n"
|
|
"[v3_req_client]\n"
|
|
"extendedKeyUsage = clientAuth\n"
|
|
"subjectAltName = otherName:""1.3.6.1.4.1.311.20.2.3;UTF8:%s\n")
|
|
|
|
with open(conffile, 'w') as file:
|
|
file.write(content % upn)
|
|
|
|
|
|
def generate_winrm_x509_cert(user_id, bits=2048):
    """Generate a cert for passwordless auth for user in project.

    Two openssl invocations are run inside a temporary directory: one
    creating a self-signed certificate plus key, one exporting them as
    PKCS#12 with an empty password.

    :param user_id: Used for the certificate subject (``/CN=<user_id>``)
                    and the ``<user_id>@localhost`` UPN.
    :param bits: RSA key size passed to openssl.
    :returns: Tuple of (base64 encoded PKCS#12 bundle, PEM certificate,
              certificate fingerprint).
    """
    subject = '/CN=%s' % user_id
    upn = '%s@localhost' % user_id

    with tempdir() as workdir:
        keyfile = os.path.abspath(os.path.join(workdir, 'temp.key'))
        conffile = os.path.abspath(os.path.join(workdir, 'temp.conf'))
        _create_x509_openssl_config(conffile, upn)

        req_cmd = ['openssl', 'req', '-x509', '-nodes', '-days', '3650',
                   '-config', conffile, '-newkey', 'rsa:%s' % bits,
                   '-outform', 'PEM', '-keyout', keyfile, '-subj', subject,
                   '-extensions', 'v3_req_client']
        certificate, _err = execute(*req_cmd, binary=True)

        # The certificate is fed on stdin; the key is read from keyfile.
        export_cmd = ['openssl', 'pkcs12', '-export',
                      '-inkey', keyfile, '-password', 'pass:']
        out, _err = execute(*export_cmd, process_input=certificate,
                            binary=True)

        private_key = base64.b64encode(out)
        # Fingerprint is computed from the certificate as returned above,
        # before any decoding to text.
        fingerprint = generate_x509_fingerprint(certificate)
        if six.PY3:
            private_key = private_key.decode('ascii')
            certificate = certificate.decode('utf-8')

    return (private_key, certificate, fingerprint)
|
|
|
|
|
|
def generate_fingerprint(public_key):
    """Return the colon-separated MD5 fingerprint of an SSH public key.

    :param public_key: Public key text; its second space-separated field is
                       the base64 key data.
    :returns: Hex digest formatted as colon-separated byte pairs.
    :raises InvalidKeypair: On any failure to parse or hash the key.
    """
    try:
        # Parsing only validates that this is a proper ssh key; the result
        # is discarded because pyca/cryptography offers no fingerprint
        # method for it.
        serialization.load_ssh_public_key(
            public_key.encode('utf-8'), backends.default_backend())

        key_blob = base64.b64decode(public_key.split(' ')[1])
        hasher = hashes.Hash(hashes.MD5(), backends.default_backend())
        hasher.update(key_blob)
        hex_digest = binascii.hexlify(hasher.finalize())
        if six.PY3:
            hex_digest = hex_digest.decode('ascii')

        pairs = [hex_digest[i:i + 2] for i in range(0, len(hex_digest), 2)]
        return ':'.join(pairs)
    except Exception:
        raise exception.InvalidKeypair(
            reason=_('failed to generate fingerprint'))
|
|
|
|
|
|
def generate_x509_fingerprint(pem_key):
    """Return the colon-separated SHA1 fingerprint of a PEM certificate.

    :param pem_key: PEM certificate as text or bytes; text is encoded as
                    UTF-8 before loading.
    :returns: Hex digest formatted as colon-separated byte pairs.
    :raises InvalidKeypair: If loading or fingerprinting the certificate
                            raises ValueError, TypeError or binascii.Error.
    """
    try:
        pem_bytes = pem_key
        if isinstance(pem_bytes, six.text_type):
            pem_bytes = pem_bytes.encode('utf-8')

        cert = x509.load_pem_x509_certificate(
            pem_bytes, backends.default_backend())
        hex_digest = binascii.hexlify(cert.fingerprint(hashes.SHA1()))
        if six.PY3:
            hex_digest = hex_digest.decode('ascii')

        pairs = [hex_digest[i:i + 2] for i in range(0, len(hex_digest), 2)]
        return ':'.join(pairs)
    except (ValueError, TypeError, binascii.Error) as ex:
        raise exception.InvalidKeypair(
            reason=_('failed to generate X509 fingerprint. '
                     'Error message: %s') % ex)
|
|
|
|
|
|
def generate_key_pair(bits=2048):
    """Generate an RSA key pair with paramiko.

    :param bits: Key size handed to paramiko.RSAKey.generate.
    :returns: Tuple of (private key text, public key line, fingerprint of
              the public key).
    """
    rsa_key = paramiko.RSAKey.generate(bits)

    # Serialize the private key into an in-memory text buffer.
    buf = six.StringIO()
    rsa_key.write_private_key(buf)
    private_key = buf.getvalue()

    key_fields = (rsa_key.get_name(), rsa_key.get_base64())
    public_key = '%s %s Generated-by-Mogan' % key_fields

    return (private_key, public_key, generate_fingerprint(public_key))
|
|
|
|
|
|
def _serialize_profile_info():
    """Capture the current profiler's identifiers as a plain dict.

    :returns: Dict with ``hmac_key``, ``base_id`` and ``parent_id``, or
              None when the profiler module is unavailable or
              ``profiler.get()`` returns nothing.
    """
    if not profiler:
        return None

    prof = profiler.get()
    if not prof:
        return None

    # FIXME(DinaBelova): we'll add profiler.get_info() method
    # to extract this info -> we'll need to update these lines
    return {
        "hmac_key": prof.hmac_key,
        "base_id": prof.get_base_id(),
        "parent_id": prof.get_id()
    }
|
|
|
|
|
|
def spawn_n(func, *args, **kwargs):
    """Passthrough method for eventlet.spawn_n.

    This utility exists so that it can be stubbed for testing without
    interfering with the service spawns.

    The current context (from the threadlocal store) and the serialized
    profiler info are captured here and re-established inside the spawned
    thread before ``func`` runs, which keeps logging context continuous.
    The return value of ``func`` is discarded.

    :param func: Callable to run in the new green thread.
    :param args: Positional arguments for ``func``.
    :param kwargs: Keyword arguments for ``func``.
    """
    parent_context = common_context.get_current()
    trace_info = _serialize_profile_info()

    @functools.wraps(func)
    def context_wrapper(*call_args, **call_kwargs):
        # NOTE: If update_store is not called after spawn_n it won't be
        # available for the logger to pull from threadlocal storage.
        if parent_context is not None:
            parent_context.update_store()
        if trace_info and profiler:
            profiler.init(**trace_info)
        func(*call_args, **call_kwargs)

    eventlet.spawn_n(context_wrapper, *args, **kwargs)
|