#! /usr/bin/env python3

# Copyright 2011, 2013-2014 OpenStack Foundation
# Copyright 2012 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import configparser
import argparse
import irc.client
import logging
import ssl
import sys
import time
import yaml

# Configure root logging once at import time: every record at DEBUG level
# or above is emitted, prefixed with a timestamp, the level and the logger
# name.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s [%(levelname)s] %(name)s - %(message)s',
)


class SetAccess(irc.client.SimpleIRCClient):
    """IRC bot that reconciles ChanServ access lists and mlocks with config.

    Once identified to NickServ, the bot takes each channel listed in
    ``config['channels']`` in turn. It asks ChanServ for the channel's
    current mlock and access list and computes the ChanServ commands needed
    to reach the state described by the config. It then sends those
    commands one at a time, or only logs them when ``noop`` is set. It
    quits after the last channel has been handled.

    Config keys read here: ``channels`` (list of dicts with a ``name``),
    ``global``, ``access``, and optionally ``op_channels``, ``alumni`` and
    ``mode``.
    """

    log = logging.getLogger("setaccess")

    def __init__(self, config, noop, nick, password, server, port):
        """Store settings and connect to ``server:port`` as ``nick``.

        :param config: parsed channel configuration (see class docstring).
        :param noop: if true, log the intended changes instead of sending
            them to ChanServ.
        :param nick: nick to connect as.
        :param password: NickServ password used to identify.
        :param server: IRC server hostname.
        :param port: server port (anything ``int()`` accepts). Port 6697 is
            treated as a TLS port; any other port connects in plain text.
        """
        irc.client.SimpleIRCClient.__init__(self)
        self.config = config
        self.nick = nick
        self.password = password
        self.server = server
        self.port = int(port)
        self.noop = noop
        # Channels still to process; advance() pops from the end.
        self.channels = [x['name'] for x in self.config['channels']]
        self.current_channel = None
        self.current_list = []
        self.current_mode = ''
        self.changes = []
        self.identified = False
        if self.port == 6697:
            # ssl.wrap_socket() was removed in Python 3.12, and it never
            # verified the server certificate. A default client context
            # checks the certificate chain and the hostname before our
            # NickServ password is sent over this connection.
            context = ssl.create_default_context()
            factory = irc.connection.Factory(
                wrapper=lambda sock: context.wrap_socket(
                    sock, server_hostname=self.server))
            self.connect(self.server, self.port, self.nick,
                         connect_factory=factory)
        else:
            self.connect(self.server, self.port, self.nick)

    def on_disconnect(self, connection, event):
        """Exit the process with status 0 once the connection closes."""
        sys.exit(0)

    def on_privnotice(self, c, e):
        """Handle a private NOTICE, normally from a services bot.

        Before identification, only NickServ notices matter. The
        authentication prompt triggers an ``identify``. The success reply
        joins and requests op in ``op_channels`` and then starts channel
        processing. After identification, notices from ChanServ or
        NickServ are fed to advance(); anything else is logged and ignored.
        """
        nick = e.source.split('!')[0]
        msg = e.arguments[0]
        if nick == 'NickServ' and not self.identified:
            if msg.startswith('authenticate yourself to services'):
                self.log.debug("Identifying to nickserv")
                # TODO (fungi): We should protect against sending our
                # password to a false NickServ, perhaps with
                # https://www.oftc.net/NickServ/CertFP/ or eventually
                # SASL once the ircd implements that
                c.privmsg("nickserv", "identify %s " % self.password)
                return
            elif msg.startswith('You are successfully identified'):
                self.identified = True
                # Prejoin and set ourselves as op in these channels,
                # to facilitate +f forwarding.
                for channel in self.config.get('op_channels', []):
                    c.join("#%s" % channel)
                    c.privmsg("chanserv", "op #%s" % channel)
                self.advance()
                return
            else:
                return
        if nick not in ('ChanServ', 'NickServ'):
            self.log.debug("Ignoring message from non-ChanServ "
                           "user %s" % nick)
            return
        self.failed = False
        self.advance(msg)

    def _get_access_list(self, channel_name):
        """Return the desired ``(access, alumni, mode)`` for a channel.

        Entries from ``config['global']`` are combined with those of the
        channel's own config entry, and the channel entries are processed
        last. ``alumni`` lists are concatenated, and a channel ``mode``
        overrides the global one. Every other key is looked up in
        ``config['access']``. If it names a defined access level, each nick
        listed under it is mapped to that level; when a nick appears under
        several keys, the last one processed wins.

        :returns: tuple of (dict nick -> access level, list of alumni
            nicks, mode string, '' if none configured).
        :raises Exception: if ``channel_name`` is not in the config.
        """
        ret = {}
        alumni = []
        mode = ''
        channel = None
        for c in self.config['channels']:
            if c['name'] == channel_name:
                channel = c
        if channel is None:
            raise Exception("Unknown channel %s" % (channel_name,))
        for key, value in (list(self.config['global'].items()) +
                           list(channel.items())):
            if key == 'alumni':
                alumni += value
                continue
            if key == 'mode':
                mode = value
                continue

            # If we get this far, we assume the key is an access
            # level matching an entry in the access list
            level = self.config['access'].get(key)
            if level is None:
                # Skip if this doesn't match a defined access level
                continue
            for nick in value:
                ret[nick] = level
        return ret, alumni, mode

    def _get_access_change(self, current, target):
        """Return ``target`` if it differs from ``current``, else None."""
        if current != target:
            return target

    def _get_access_changes(self):
        """Build the ChanServ commands needed for ``self.current_channel``.

        Compares the access list and mlock collected from ChanServ
        (``self.current_list`` / ``self.current_mode``) with the configured
        target. The resulting commands:

        * delete any listed nick that is an alumnus (global or per-channel);
        * add or update nicks whose level differs from the target (nicks
          present only on the server are left alone);
        * reset the mlock if its characters differ from the target mode,
          which defaults to ``'+'`` when none is configured.

        :returns: list of ChanServ command strings.
        """
        target, alumni, mode = self._get_access_list(
            self.current_channel)
        self.log.debug("Target #%s ACL: %s" % (self.current_channel, target))
        all_nicks = set()
        global_alumni = self.config.get('alumni', {})
        global_mode = self.config.get('mode', '')
        current = {}
        changes = []
        for nick, level, msg in self.current_list:
            if nick in global_alumni or nick in alumni:
                self.log.debug("%s is an alumni; removing access", nick)
                changes.append('access #%s del %s' % (self.current_channel,
                                                      nick))
                continue
            all_nicks.add(nick)
            current[nick] = level
        all_nicks.update(target)
        for nick in all_nicks:
            change = self._get_access_change(current.get(nick, ''),
                                             target.get(nick, ''))
            if change:
                changes.append('access #%s add %s %s' % (self.current_channel,
                                                         nick, change))

        # Set the mode if what we want differs from what's already there.
        # Channel mode overrides global mode.
        if not mode and global_mode:
            mode = global_mode
        if not mode:
            mode = '+'
        # Compare as sorted characters so flag order does not matter.
        if sorted(mode) != sorted(self.current_mode):
            self.log.debug("Current mode for #%s is %s, replacing with %s" % (
                self.current_channel, self.current_mode, mode))
            changes.append('set #%s mlock %s' % (self.current_channel, mode))

        return changes

    def advance(self, msg=None):
        """Step the sync state machine.

        Called with no argument to start (after identification) and with
        the text of each services notice afterwards.

        * Pending changes are sent one per call; in noop mode they are all
          logged at once instead.
        * When no channel is in progress, the next channel's mlock and
          access list are requested, or the bot quits if none remain.
        * Otherwise ``msg`` is parsed as part of the current channel's
          ChanServ replies. Messages that are not recognised are ignored.
        """
        # Some service responses include a number of embedded 0x02 bytes
        if msg:
            msg = msg.replace('\x02', '')
        if self.changes:
            if self.noop:
                for change in self.changes:
                    self.log.info('NOOP: ' + change)
                self.changes = []
            else:
                # Send one change per call; the services reply to it comes
                # back through on_privnotice() and triggers the next call.
                # Changes are sent last-first.
                change = self.changes.pop()
                self.log.info(change)
                self.connection.privmsg('chanserv', change)
                # Presumably throttles commands so services are not
                # flooded.
                time.sleep(1)
                return
        if not self.current_channel:
            if not self.channels:
                self.connection.quit()
                return
            self.current_channel = self.channels.pop()
            # Clear the mode string before we request it, so if we get
            # no response we won't have the modes from an earlier channel
            self.current_mode = ''
            # Sending a set mlock with no value prompts the service to
            # respond with the current mlock value so we can compare
            # against it later
            self.connection.privmsg('chanserv', 'set #%s mlock' %
                                    self.current_channel)
            # Clear the access list before we request it, so if we get
            # no response we won't have the list from an earlier channel
            self.current_list = []
            self.connection.privmsg('chanserv', 'access #%s list' %
                                    self.current_channel)
            time.sleep(1)
            return
        # Nothing to parse (no message, or an empty notice).
        if not msg:
            return
        # We tokenize every server message, and perform some rough
        # heuristics in order to determine what kind of response we're
        # dealing with and whether it's something we know how to parse.
        # Length checks keep short, unrelated notices from raising
        # IndexError and killing the bot.
        parts = msg.split()
        # If the third word looks like an access level, assume this is
        # an access list entry and that the second word is a
        # corresponding nick
        if len(parts) > 2 and parts[2] in ('MASTER', 'CHANOP', 'MEMBER'):
            self.current_list.append((parts[1], parts[2], msg))
        # If the message starts with "MLOCK is SET to" then assume the
        # fifth word is the channel's mode string
        elif msg.startswith('MLOCK is SET to') and len(parts) > 4:
            self.current_mode = parts[4]
        # If the message starts with "End of" then assume this marks
        # the end of an access list
        elif msg.startswith('End of'):
            self.changes = self._get_access_changes()
            self.current_channel = None
            self.advance()
            return


def main():
    """Parse arguments, load configuration and run the access sync.

    IRC credentials (``nick``, ``pass``, ``server``, ``port``) are read
    from the ``[ircbot]`` section of the INI file given with ``-c``. The
    desired channel state is read from the YAML file given with ``-l``.
    The SetAccess client then runs until it disconnects.
    """
    parser = argparse.ArgumentParser(description='IRC channel access check')
    # nargs=1 makes args.config a one-element list, which
    # ConfigParser.read() accepts. The option is required: without it
    # config.read(None) would fail with an unhelpful TypeError.
    parser.add_argument('-c', dest='config', nargs=1, required=True,
                        help='specify the config file')
    parser.add_argument('-l', dest='channels',
                        default='/etc/irc/channels.yaml',
                        help='path to the channel config')
    parser.add_argument('--noop', dest='noop',
                        action='store_true',
                        help="Don't make any changes")
    args = parser.parse_args()

    config = configparser.ConfigParser()
    # read() silently skips files it cannot open and returns the list of
    # files it did read; fail early with a clear message instead of a
    # NoSectionError further down.
    if not config.read(args.config):
        parser.error('could not read config file %s' % args.config[0])

    # Close the channel file promptly; safe_load refuses arbitrary
    # Python object tags in the YAML.
    with open(args.channels) as channels_file:
        channels = yaml.safe_load(channels_file)

    a = SetAccess(channels, args.noop,
                  config.get('ircbot', 'nick'),
                  config.get('ircbot', 'pass'),
                  config.get('ircbot', 'server'),
                  config.get('ircbot', 'port'))
    a.start()


# Run the access sync only when executed as a script, not when imported.
if __name__ == "__main__":
    main()