# Copyright (c) 2014 Mirantis Inc.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import functools
import logging
from multiprocessing import managers
import os
import shutil
import signal
import stat
import sys
import tempfile
import threading
import time

from oslo_rootwrap import cmd
from oslo_rootwrap import jsonrpc
from oslo_rootwrap import subprocess
from oslo_rootwrap import wrapper

# Module-level logger used by the daemon code in this file.
LOG = logging.getLogger(__name__)

# multiprocessing.managers only ships the 'pickle' and 'xmlrpclib' serializers
# for RPC requests and responses, so we register an additional 'jsonrpc'
# entry made of the (listener, client) pair from oslo_rootwrap.jsonrpc.

managers.listener_client['jsonrpc'] = jsonrpc.JsonListener, jsonrpc.JsonClient


class RootwrapClass(object):
    """Command executor exposed to rootwrap daemon clients.

    ``get_manager_class`` registers this class under the ``rootwrap`` typeid.
    Besides running commands it keeps the idle-shutdown bookkeeping:
    ``last_called``, ``daemon_timeout`` and ``timeout`` (the pending
    ``threading.Timer``) are stored on the class itself, so they are shared
    by every instance.
    """

    def __init__(self, config, filters):
        """Remember the configuration and filters and arm the idle timer.

        :param config: configuration object; ``exec_dirs``, ``use_syslog``
                       and ``daemon_timeout`` are read from it
        :param filters: filters passed on to ``wrapper.start_subprocess``
        """
        self.config = config
        self.filters = filters
        self.reset_timer()
        self.prepare_timer(config)

    def run_one_command(self, userargs, stdin=None):
        """Run one command through the wrapper and collect its output.

        :param userargs: sequence of strings forming the command line
        :param stdin: optional text fed to the process' standard input
        :returns: ``(return code, stdout, stderr)``; when the wrapper
                  rejects the command, the return code is
                  ``cmd.RC_NOEXECFOUND`` or ``cmd.RC_UNAUTHORIZED`` and both
                  output strings are empty
        """
        # Every request counts as activity for the idle-shutdown timer.
        self.reset_timer()
        try:
            obj = wrapper.start_subprocess(
                self.filters, userargs,
                exec_dirs=self.config.exec_dirs,
                log=self.config.use_syslog,
                close_fds=True,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)
        except wrapper.FilterMatchNotExecutable:
            LOG.warning("Executable not found for: %s",
                        ' '.join(userargs))
            return cmd.RC_NOEXECFOUND, "", ""

        except wrapper.NoFilterMatched:
            LOG.warning("Unauthorized command: %s (no filter matched)",
                        ' '.join(userargs))
            return cmd.RC_UNAUTHORIZED, "", ""

        # The pipes carry bytes while the RPC layer exchanges text, so
        # convert with the filesystem encoding in both directions.
        if stdin is not None:
            stdin = os.fsencode(stdin)
        out, err = obj.communicate(stdin)
        out = os.fsdecode(out)
        err = os.fsdecode(err)
        return obj.returncode, out, err

    @classmethod
    def reset_timer(cls):
        """Record the current time as the moment of the last activity."""
        cls.last_called = time.time()

    @classmethod
    def cancel_timer(cls):
        """Cancel the pending idle timer, if there is one.

        Safe to call even when no timer was ever armed, e.g. when the daemon
        is stopped before any instance of this class has been created.
        """
        timer = getattr(cls, 'timeout', None)
        if timer is None:
            return
        try:
            timer.cancel()
        except RuntimeError:
            pass

    @classmethod
    def prepare_timer(cls, config=None):
        """Arm the idle timer unless one is already pending.

        :param config: when given, its ``daemon_timeout`` replaces the
                       class-wide timeout before the timer is computed
        """
        if config is not None:
            cls.daemon_timeout = config.daemon_timeout
        if getattr(cls, 'timeout', None):
            # Another timer is already initialized
            return
        # Wait a bit longer to avoid rounding errors
        timeout = max(
            cls.last_called + cls.daemon_timeout - time.time(),
            0) + 1
        cls.timeout = threading.Timer(timeout, cls.handle_timeout)
        cls.timeout.start()

    @classmethod
    def handle_timeout(cls):
        """Timer callback: shut down if idle for too long, then re-arm."""
        if cls.last_called < time.time() - cls.daemon_timeout:
            cls.shutdown()

        # The timer that invoked this callback has already fired and cannot
        # be reused.  Forget it, otherwise prepare_timer() would consider it
        # still pending and the idle check would never run again.
        cls.timeout = None
        cls.prepare_timer()

    @staticmethod
    def shutdown():
        """Interrupt the daemon process by sending SIGINT to ourselves."""
        # Suicide to force break of the main thread
        os.kill(os.getpid(), signal.SIGINT)


def get_manager_class(config=None, filters=None):
    """Build a fresh manager class with the ``rootwrap`` typeid registered.

    :param config: when given, ``rootwrap`` is registered with a factory
                   that creates ``RootwrapClass(config, filters)``; when
                   ``None``, the typeid is registered without a callable
    :param filters: filters bound into the factory together with ``config``
    :returns: a new ``BaseManager`` subclass that always uses the
              ``jsonrpc`` serializer
    """
    class RootwrapManager(managers.BaseManager):
        def __init__(self, address=None, authkey=None):
            # Force jsonrpc because neither pickle nor xmlrpclib is secure
            super().__init__(address, authkey, serializer='jsonrpc')

    if config is None:
        RootwrapManager.register('rootwrap')
        return RootwrapManager

    factory = functools.partial(RootwrapClass, config, filters)
    RootwrapManager.register('rootwrap', factory)
    return RootwrapManager


def daemon_start(config, filters):
    """Run the rootwrap daemon until it is asked to stop.

    Creates a private temporary directory holding a ``rootwrap.sock`` socket,
    writes the socket path and the server authkey to standard output, closes
    the standard streams and then serves requests.  On the way out it closes
    the listener and accepted connections, cancels the idle timer, joins the
    daemon threads and removes the temporary directory.

    :param config: configuration object passed to ``get_manager_class``
    :param filters: command filters passed to ``get_manager_class``
    """
    temp_dir = tempfile.mkdtemp(prefix='rootwrap-')
    LOG.debug("Created temporary directory %s", temp_dir)
    try:
        # allow everybody to find the socket
        rwxr_xr_x = (stat.S_IRWXU |
                     stat.S_IRGRP | stat.S_IXGRP |
                     stat.S_IROTH | stat.S_IXOTH)
        os.chmod(temp_dir, rwxr_xr_x)
        socket_path = os.path.join(temp_dir, "rootwrap.sock")
        LOG.debug("Will listen on socket %s", socket_path)
        manager_cls = get_manager_class(config, filters)
        manager = manager_cls(address=socket_path)
        server = manager.get_server()
        try:
            # allow everybody to connect to the socket
            rw_rw_rw_ = (stat.S_IRUSR | stat.S_IWUSR |
                         stat.S_IRGRP | stat.S_IWGRP |
                         stat.S_IROTH | stat.S_IWOTH)
            os.chmod(socket_path, rw_rw_rw_)
            try:
                # In Python 3 we have to use buffer to push in bytes directly
                stdout = sys.stdout.buffer
            except AttributeError:
                stdout = sys.stdout
            # Handshake on stdout: socket path, a newline, then the raw
            # authkey bytes.
            stdout.write(socket_path.encode('utf-8'))
            stdout.write(b'\n')
            stdout.write(bytes(server.authkey))
            # Nothing else is exchanged over the standard streams after the
            # handshake, so close all three.
            sys.stdin.close()
            sys.stdout.close()
            sys.stderr.close()
            # Gracefully shutdown on INT or TERM signals
            stop = functools.partial(daemon_stop, server)
            signal.signal(signal.SIGTERM, stop)
            signal.signal(signal.SIGINT, stop)
            LOG.info("Starting rootwrap daemon main loop")
            server.serve_forever()
        finally:
            conn = server.listener
            # This will break accept() loop with EOFError if it was not in the
            # main thread (as in Python 3.x)
            conn.close()
            # Closing all currently connected client sockets for reading to
            # break worker threads blocked on recv()
            # NOTE(review): get_accepted()/half_close() are presumably
            # provided by jsonrpc.JsonListener and its connections - confirm.
            for cl_conn in conn.get_accepted():
                try:
                    cl_conn.half_close()
                except Exception:
                    # Most likely the socket has already been closed
                    LOG.debug("Failed to close connection")
            # Stop the idle timer before waiting on the remaining threads.
            RootwrapClass.cancel_timer()
            LOG.info("Waiting for all client threads to finish.")
            # Only daemonic threads are joined here.
            for thread in threading.enumerate():
                if thread.daemon:
                    LOG.debug("Joining thread %s", thread)
                    thread.join()
    finally:
        LOG.debug("Removing temporary directory %s", temp_dir)
        shutil.rmtree(temp_dir)


def daemon_stop(server, signal, frame):
    """Signal handler that asks the manager server to stop serving.

    Signals are caught in the main thread, so this handler runs in the
    middle of the ``serve_forever()`` loop.  When the server exposes a
    ``stop_event`` (Python 3) it is simply set.  Otherwise
    ``KeyboardInterrupt`` is raised, which ``serve_forever()`` is expected
    to catch before returning properly.  Since all threads created by
    ``serve_forever()`` are daemonic, the caller has to join them afterwards.

    :param server: manager server whose main loop should be stopped
    :param signal: number of the received signal (only logged)
    :param frame: interrupted stack frame (unused)
    """
    LOG.info("Got signal %s. Shutting down server", signal)
    try:
        stop_event = server.stop_event
        stop_event.set()
    except AttributeError:
        raise KeyboardInterrupt