# Copyright 2020 Red Hat, Inc.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import functools
import re
import subprocess
import time

from neutron_tempest_plugin.common import shell
from neutron_tempest_plugin.common import utils as common_utils
from oslo_log import log
from tempest import config
from tempest.lib import exceptions

from whitebox_neutron_tempest_plugin.common import constants

# Global tempest configuration object.
CONF = config.CONF
LOG = log.getLogger(__name__)
# Shortcut to the "whitebox_neutron_plugin_options" option group, read by
# the helpers below (e.g. openstack_type, ml2_plugin_config).
WB_CONF = CONF.whitebox_neutron_plugin_options


def create_payload_file(ssh_client, size):
    """Create a zero-filled payload file of ``size`` bytes remotely.

    The file is created relative to the remote shell's current directory
    and is named after its size (i.e. the file name is ``str(size)``).
    """
    command = "head -c {size} /dev/zero > {size}".format(size=size)
    ssh_client.exec_command(command)


def get_temp_file(ssh_client):
    """Create a temporary file on the remote host and return its path.

    The path printed by ``mktemp`` is returned without trailing whitespace
    (so the final newline is dropped).
    """
    raw_path = ssh_client.exec_command('mktemp')
    return raw_path.rstrip()


def cat_remote_file(ssh_client, path):
    """Return the content of ``path`` on the remote host.

    Trailing whitespace of the output is stripped.
    """
    content = ssh_client.exec_command(f'cat {path}')
    return content.rstrip()


def get_default_interface(ssh_client):
    """Return the interface the remote host uses to reach GLOBAL_IP.

    Takes the 5th space-separated field of the first line printed by
    ``ip route get``. This assumes output shaped like
    ``<dst> via <gateway> dev <iface> ...`` (i.e. a gatewayed route) -
    NOTE(review): a directly connected route would shift the fields.
    """
    route_cmd = ("PATH=$PATH:/usr/sbin ip route get default %s | head -1 | "
                 "cut -d ' ' -f 5") % constants.GLOBAL_IP
    raw_iface = ssh_client.exec_command(route_cmd)
    return raw_iface.rstrip()


def get_route_interface(ssh_client, dst_ip):
    """Return the interface the remote host would use to reach ``dst_ip``.

    Runs ``ip route get`` on the remote host and returns the token that
    follows ``dev`` in the (first) output line.

    :param ssh_client: SSH client connected to the host to query.
    :param dst_ip: destination address to look up.
    :returns: interface name, or None when the output is empty or does not
              name a device (no ``dev`` token followed by a value).
    """
    output = ssh_client.exec_command(
        "PATH=$PATH:/usr/sbin ip route get default %s | head -1" % dst_ip)
    for line in (output or '').splitlines():
        fields = line.split()
        # 'dev' must be followed by the interface name; a line lacking it
        # (or ending with it) cannot identify an interface, so keep looking
        # instead of raising ValueError/IndexError.
        if 'dev' in fields[:-1]:
            return fields[fields.index('dev') + 1]
    return None


def make_sure_local_port_is_open(protocol, port):
    """Ensure the local INPUT chain accepts ``protocol`` traffic to ``port``.

    An ACCEPT rule is inserted at the top of the INPUT chain only when
    ``iptables-save`` does not already show a matching one.
    """
    check_rule = ("sudo iptables-save | "
                  rf"grep 'INPUT.*{protocol}.*\-\-dport {port} \-j ACCEPT'")
    add_rule = ("sudo iptables -I INPUT 1 "
                f"-p {protocol} --dport {port} -j ACCEPT")
    # grep succeeds when the rule exists, which short-circuits the insert.
    shell.execute_local_command(f"{check_rule} && true || {add_rule}")


# Unlike the ncat server function from the upstream plugin, this ncat server
# turns itself off automatically after a timeout.
def run_ncat_server(ssh_client, udp):
    """Start an nc listener on the remote host, bounded by NCAT_TIMEOUT.

    :param ssh_client: SSH client of the host that should listen.
    :param udp: extra nc option(s) placed before ``-p``; presumably ``-u``
                for UDP or an empty string for TCP - confirm with callers.
    :returns: remote path of the temp file that receives nc's output.
    """
    received_data_file = get_temp_file(ssh_client)
    server_cmd = "sudo timeout {tmo} nc -l {udp} -p {port} > {out}".format(
        tmo=constants.NCAT_TIMEOUT, udp=udp,
        port=constants.NCAT_PORT, out=received_data_file)
    LOG.debug("Starting nc server: '%s'", server_cmd)
    # Run in a fresh session; presumably so this call returns without
    # waiting for the listener to finish.
    ssh_client.open_session().exec_command(server_cmd)
    return received_data_file


# Unlike the ncat client function from the upstream plugin, this ncat client
# is able to run from any host, not only locally.
def run_ncat_client(ssh_client, host, udp, payload_size):
    """Send a payload file to an nc server listening on ``host``.

    :param ssh_client: SSH client of the host that sends the payload.
    :param host: address of the nc server.
    :param udp: extra nc option(s) inserted before the port.
    :param payload_size: name of the file redirected into nc; presumably the
                         file written by create_payload_file, which names
                         the file after its size.
    """
    client_cmd = "nc -w 1 {dst} {udp} {port} < {payload}".format(
        dst=host, udp=udp, port=constants.NCAT_PORT, payload=payload_size)
    LOG.debug("Starting nc client: '%s'", client_cmd)
    ssh_client.exec_command(client_cmd)


def flush_routing_cache(ssh_client):
    """Flush the routing cache on the remote host (runs ip via sudo)."""
    flush_cmd = "sudo ip route flush cache"
    ssh_client.exec_command(flush_cmd)


def kill_iperf_process(ssh_client):
    """Best-effort kill of iperf3 processes on the remote host.

    A failing ``pkill`` (for example because no iperf3 process matches) is
    deliberately ignored.
    """
    try:
        ssh_client.exec_command("PATH=$PATH:/usr/sbin pkill iperf3")
    except exceptions.SSHExecCommandFailed:
        # Nothing left to kill is not an error for the callers.
        return


def configure_interface_up(client, port, interface=None, path=None):
    """Configure a VM interface with the port's IP and activate it.

    With an advanced image (``default_image_is_advanced``), dhclient is run
    on the interface unless the port's first fixed IP is already present on
    it. Otherwise, if the interface operstate contains "down", the link is
    set up and the port's first fixed IP is added with a /24 prefix.
    The resulting command is retried through ``wait_until_true`` (timeout
    30s, sleep 5s) until it runs successfully.

    Parameters:
        client (ssh.Client): ssh client which has interface to configure.
        port (port): port object of interface; its ``mac_address`` and the
            first ``fixed_ips`` entry are used.
        interface (str): optional interface name on vm; when omitted it is
            looked up on the VM by the port MAC address.
        path (str): optional shell PATH variable assignment.
    """
    shell_path = path or "PATH=$PATH:/sbin"
    test_interface = interface or client.exec_command(
        "{};ip addr | grep {} -B 1 | head -1 | "
        r"cut -d ':' -f 2 | sed 's/\ //g'".format(
            shell_path, port['mac_address'])).rstrip()
    ip_address = port['fixed_ips'][0]['ip_address']

    if CONF.neutron_plugin_options.default_image_is_advanced:
        # Match the address as a fixed string and a whole word: a plain
        # regex grep for 10.0.0.1 would also match 10.0.0.10 (and '.'
        # matches any character), wrongly skipping dhclient.
        cmd = ("ip addr show {interface} | grep -wF {ip} || "
               "sudo dhclient {interface}").format(
                   ip=ip_address,
                   interface=test_interface)
    else:
        # Only touch the interface when its operstate reports "down".
        cmd = ("cat /sys/class/net/{interface}/operstate | "
               "grep -q -v down && true || "
               "({path}; sudo ip link set {interface} up && "
               "sudo ip addr add {ip}/24 dev {interface})").format(
                   path=shell_path,
                   ip=ip_address,
                   interface=test_interface)

    common_utils.wait_until_true(
        lambda: execute_command_safely(client, cmd), timeout=30, sleep=5)


def _parse_nmcli_dhcp_output(output, ip_version):
    """Convert ``nmcli -f DHCP<ver>`` output into an option -> value dict.

    Lines look like ``DHCP<ver>.OPTION[<n>]:   <name> = <value>``. Only the
    first ``=`` separates name and value, so values containing ``=`` stay
    intact; lines without ``=`` are skipped. Names listed in
    ``constants.DHCP_OPTIONS_NMCLI_TO_NEUTRON`` are translated to their
    Neutron equivalent.
    """
    prefix_re = re.compile(
        r'^DHCP{}\.OPTION\[[0-9]+\]:\s+'.format(ip_version))
    dhcp_opts = {}
    for line in output.splitlines():
        option, sep, value = prefix_re.sub('', line.strip()).partition('=')
        if not sep:
            continue
        option = option.strip()
        if option in constants.DHCP_OPTIONS_NMCLI_TO_NEUTRON:
            option = constants.DHCP_OPTIONS_NMCLI_TO_NEUTRON[option]
        dhcp_opts[option] = value.strip()
    return dhcp_opts


def parse_dhcp_options_from_nmcli(
        ssh_client, ip_version,
        timeout=20.0, interval=5.0, expected_empty=False, vlan=None):
    """Return the DHCP options NetworkManager obtained on a VM.

    Runs ``nmcli -f DHCP<ip_version> con show`` for the active connection
    (loopback excluded, optionally filtered by ``vlan``). The command is
    retried every ``interval`` seconds until it succeeds and, unless
    ``expected_empty`` is set, produces non-empty output.

    :param ssh_client: SSH client connected to the VM.
    :param ip_version: IP version (4 or 6) selecting the DHCP4/DHCP6 fields.
    :param timeout: seconds to keep retrying nmcli; the time needed to
                    establish the SSH connection is not included.
    :param interval: seconds to sleep between attempts.
    :param expected_empty: accept empty nmcli output instead of retrying.
    :param vlan: optional grep pattern the connection name must match.
    :returns: dict mapping option name to value, or None if the output is
              empty.
    :raises exceptions.TimeoutException: if no acceptable nmcli output is
                                         obtained within ``timeout``.
    """
    # First of all, test that the ssh connection is available - the time it
    # takes until the ssh connection can be established is not considered
    # for the nmcli timeout.
    ssh_client.test_connection_auth()
    # Exclude the loopback interface with grep -v: managing the loopback
    # interface with NetworkManager is included in the RHEL 9.2 image, but
    # not in previous versions.
    cmd_find_connection = 'nmcli -g NAME con show --active | grep -v "^lo"'
    if vlan is not None:
        cmd_find_connection += ' | grep {}'.format(vlan)
    # The connection name is resolved on the VM via command substitution.
    cmd_show_dhcp = ('sudo nmcli -f DHCP{} con show '
                     '"$({})"').format(ip_version, cmd_find_connection)

    start_time = time.time()
    while True:
        try:
            output = ssh_client.exec_command(cmd_show_dhcp)
        except exceptions.SSHExecCommandFailed:
            LOG.warning('Failed to run nmcli on VM - retrying...')
        else:
            if not output and not expected_empty:
                LOG.warning('nmcli result on VM is empty - retrying...')
            else:
                break
        if time.time() - start_time > timeout:
            message = ('Failed to run nmcli on VM after {} '
                       'seconds'.format(timeout))
            raise exceptions.TimeoutException(message)
        time.sleep(interval)

    if not output:
        LOG.warning('Failed to obtain DHCP opts')
        return None
    return _parse_nmcli_dhcp_output(output, ip_version)


def execute_command_safely(ssh_client, command):
    """Run ``command`` remotely and report success instead of raising.

    :returns: True when the command ran, False when it raised
              SSHExecCommandFailed (the failure is logged).
    """
    succeeded = True
    try:
        output = ssh_client.exec_command(command)
    except exceptions.SSHExecCommandFailed as err:
        LOG.warning('command failed: %s', command)
        LOG.exception(err)
        succeeded = False
    else:
        LOG.debug('command executed successfully: %s\ncommand output:\n%s',
                  command, output)
    return succeeded


def host_responds_to_ping(ip, count=3):
    """Return True if ``ping -c<count> <ip>`` run locally exits with 0.

    The command runs through ``bash -c`` with its stdout captured and
    discarded; any non-zero exit status yields False.
    """
    ping_cmd = f"ping -c{count} {ip}"
    result = subprocess.run(['bash', '-c', ping_cmd], stdout=subprocess.PIPE)
    return result.returncode == 0


def run_local_cmd(cmd, timeout=10):
    """Run a shell command locally, wrapped by ``timeout <timeout>``.

    :returns: (stdout, stderr) tuple as captured by ``communicate()``.
    """
    command = " ".join(("timeout", str(timeout), cmd))
    LOG.debug("Running local command '%s'", command)
    proc = subprocess.Popen(command, shell=True,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
    stdout_data, stderr_data = proc.communicate()
    return stdout_data, stderr_data


def interface_state_set(client, interface, state):
    """Run ``ip link set <interface> <state>`` on the host behind ``client``.

    ``state`` is passed verbatim to ``ip link set``.
    """
    LOG.debug('Setting interface %s %s on %s',
              interface, state, client.host)
    link_cmd = "PATH=$PATH:/sbin; sudo ip link set {} {}".format(
        interface, state)
    client.exec_command(link_cmd)


def remote_service_action(client, service, action, target_state):
    """Run ``systemctl <action> <service>`` remotely and wait for a state.

    The service state is polled with remote_service_check_state every 5
    seconds for up to 30 seconds; wait_until_true is given a RuntimeError
    to raise when ``target_state`` is not reached in time.
    """
    systemctl_cmd = f"sudo systemctl {action} {service}"
    LOG.debug("Running '%s' on %s", systemctl_cmd, client.host)
    client.exec_command(systemctl_cmd)

    def _reached_target():
        return remote_service_check_state(client, service, target_state)

    timeout_error = RuntimeError(
        "Service failed to reach the required state '{}'".format(
            target_state))
    common_utils.wait_until_true(
        _reached_target, timeout=30, sleep=5, exception=timeout_error)


def remote_service_check_state(client, service, state):
    """Return True if ``systemctl is-active`` of ``service`` shows ``state``.

    The output is filtered with ``grep -w`` (whole-word match); the
    trailing ``|| true`` keeps a non-match from failing the remote command.
    """
    query = "sudo systemctl is-active {svc} | grep -w {st} || true".format(
        svc=service, st=state)
    filtered = client.exec_command(query).strip()
    return state in filtered


# NOTE(mblue): Please use specific regex to avoid dismissing various issues
def retry_on_assert_fail(max_retries,
                         assert_regex,
                         exception_type=AssertionError):
    """Decorator that retries a function on a matching assertion failure.

    The decorated function is called at most ``max_retries`` times in
    total. In order to avoid dismissing exceptions which lead to bugs, a
    caught exception is only retried when ``assert_regex`` matches its
    ``str()`` or ``repr()``; otherwise it is re-raised immediately. An
    optional, more specific exception type can also be passed.

    :param max_retries: Obligatory maximum number of attempts before failing.
    :param assert_regex: Obligatory regex that must match the exception
                         message for the call to be retried.
    :param exception_type: Optional specific exception related to failure.
    :raises AssertionError: when every attempt failed with a matching
                            exception; it is chained to the last one so the
                            root cause stays visible in the traceback.
    """
    def decor(f):
        @functools.wraps(f)
        def inner(*args, **kwargs):
            retries = 0
            last_error = None
            while retries < max_retries:
                try:
                    return f(*args, **kwargs)
                except exception_type as e:
                    if not (re.search(assert_regex, str(e)) or
                            re.search(assert_regex, repr(e))):
                        raise
                    LOG.debug("Assertion failed: %s. Retrying (%s/%s)...",
                              e, retries + 1, max_retries)
                    # Keep a reference: ``e`` is unbound after the except.
                    last_error = e
                    retries += 1
            raise AssertionError(
                f"Assert failed after {max_retries} retries.") from last_error
        return inner
    return decor


def wait_for_neutron_api(neutron_client, timeout=100):
    """Wait until the Neutron API replies.

    ``list_extensions`` is polled once per second; a RestClientException
    means the API is not ready yet.

    :param neutron_client: a Neutron client; it may or may not have admin
                           permissions.
    :param timeout: maximum time (in seconds) to wait for the Neutron API.
    """
    def _api_is_reachable():
        try:
            neutron_client.list_extensions()
        except exceptions.RestClientException:
            return False
        return True

    common_utils.wait_until_true(
        _api_is_reachable, timeout=timeout, sleep=1)


def get_neutron_api_service_name():
    """Return the Neutron API service name based on the test configuration.

    'q svc' is returned for devstack deployments and 'neutron api' for
    every other ``openstack_type``.
    """
    # NOTE: in OSP18+, the Neutron API will use WSGI by default (not the
    # eventlet server) and the name will be "neutron api"
    is_devstack = WB_CONF.openstack_type == 'devstack'
    return 'q svc' if is_devstack else 'neutron api'


def get_ml2_conf_file():
    """Return the Neutron ML2 config file path for the installation type.

    Podified and devstack deployments use WB_CONF.ml2_plugin_config as is
    (default '/etc/neutron/plugins/ml2/ml2_conf.ini'); any other type gets
    it prefixed with the puppet-generated neutron config directory.
    """
    ml2_config = WB_CONF.ml2_plugin_config
    if WB_CONF.openstack_type not in ('podified', 'devstack'):
        ml2_config = ('/var/lib/config-data/puppet-generated/neutron' +
                      ml2_config)
    return ml2_config