"""
Author: Weisen Pan
Date: 2023-10-24
"""
import time
import torch
import torch.nn.functional as F
from datetime import timedelta
from sklearn import metrics
from tqdm import tqdm
from scheduler import WarmUpLR, downLR
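# WarmUpLR and downLR are project-local schedulers (not part of torch); judging
# from their use below, WarmUpLR linearly ramps the learning rate up over a given
# number of steps and downLR decays it afterwards.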


def get_time_difference(start_time):
    """Return the time elapsed since start_time, rounded to whole seconds.

    Example: if start_time was 65.4 s ago, returns timedelta(seconds=65).
    """
    elapsed_time = time.time() - start_time
    return timedelta(seconds=int(round(elapsed_time)))


def train(config, model, train_iter, dev_iter, test_iter):
    """Train the model and evaluate on the development and test sets."""
    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    warmup_epoch = config.num_epochs // 2
    iter_per_epoch = len(train_iter)
    scheduler = downLR(optimizer, (config.num_epochs - warmup_epoch) * iter_per_epoch)
    warmup_scheduler = WarmUpLR(optimizer, warmup_epoch * iter_per_epoch)
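    # Net effect: the LR ramps up during the first num_epochs // 2 epochs and
    # decays over the rest; both schedulers are stepped once per batch below.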

    dev_best_loss = float('inf')
    dev_best_acc = 0
    test_best_acc = 0
    total_batch = 0

    for epoch in range(config.num_epochs):
        model.train()  # evaluate() switches the model to eval mode, so restore train mode each epoch
        loss_total = 0
        print(f'Epoch [{epoch + 1}/{config.num_epochs}]')
        
        predictions, true_values = [], []

        for trains, labels in tqdm(train_iter):
            # Cast inputs the same way evaluate() does so train/eval see identical dtypes
            trains, labels = trains.float().to(config.device), labels.long().to(config.device)
            outputs = model(trains)
            loss = F.cross_entropy(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if epoch < warmup_epoch:
                warmup_scheduler.step()
            else:
                scheduler.step()

            total_batch += 1
            loss_total += loss.item()
            
            predictions.extend(outputs.argmax(dim=1).tolist())
            true_values.extend(labels.tolist())
            
        train_acc = get_accuracy(true_values, predictions)
        dev_acc, dev_loss = evaluate(config, model, dev_iter)
        test_acc, test_loss = evaluate(config, model, test_iter)

        # '*' in the log line below marks an epoch where the dev loss improved
        if dev_loss < dev_best_loss:
            dev_best_loss = dev_loss
            improvement_marker = '*'
        else:
            improvement_marker = ''

        # Model selection on dev accuracy: report the test accuracy from the
        # best-dev epoch, not the best test accuracy ever observed
        if dev_acc > dev_best_acc:
            dev_best_acc = dev_acc
            test_best_acc = test_acc

        elapsed_time = get_time_difference(start_time)
        print((
            f'Iter: {total_batch:6}, Train Loss: {loss_total/len(train_iter):.2f}, '
            f'Train Acc: {train_acc:.2%}, Dev Loss: {dev_loss:.2f}, Dev Acc: {dev_acc:.2%}, '
            f'Test Loss: {test_loss:.2f}, Test Acc: {test_acc:.2%}, Time: {elapsed_time} {improvement_marker}'
        ))
        print(f'Best Dev Acc: {dev_best_acc:.2%}, Best Test Acc: {test_best_acc:.2%}')

    test(config, model, test_iter)


def test(config, model, test_iter):
    """Evaluate the model on the test set."""
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_confusion = evaluate(config, model, test_iter, test=True)
    print(f'Test Loss: {test_loss:.2f}, Test Acc: {test_acc:.2%}')
    print(test_confusion)
    elapsed_time = get_time_difference(start_time)
    print(f"Time usage: {elapsed_time}")


def evaluate(config, model, data_iter, test=False):
    """Evaluate the model on a dataset.

    Returns (accuracy, mean loss); with test=True, also the confusion matrix.
    """
    model.eval()
    loss_total = 0
    predictions, true_values = [], []

    with torch.no_grad():
        for texts, labels in data_iter:
            texts, labels = texts.float().to(config.device), labels.long().to(config.device)
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss.item()

            predictions.extend(outputs.argmax(dim=1).tolist())
            true_values.extend(labels.tolist())

    acc = get_accuracy(true_values, predictions)

    if test:
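        # sklearn convention: confusion[i, j] counts samples of true class i
        # that were predicted as class j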
        confusion = metrics.confusion_matrix(true_values, predictions)
        return acc, loss_total / len(data_iter), confusion

    return acc, loss_total / len(data_iter)


def get_accuracy(y_true, y_pred):
    """Calculate accuracy."""
    return metrics.accuracy_score(y_true, y_pred)
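

if __name__ == '__main__':
    # Smoke test (a minimal sketch, not part of the original pipeline): trains a
    # tiny linear classifier on random data so the loop above can be exercised
    # without the project's real config/dataset classes. `SimpleNamespace` only
    # mimics the config attributes this module actually reads: device,
    # learning_rate, and num_epochs.
    from types import SimpleNamespace
    from torch.utils.data import DataLoader, TensorDataset

    demo_config = SimpleNamespace(
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        learning_rate=1e-3,
        num_epochs=2,
    )
    features = torch.randn(64, 16)          # 64 samples, 16 features
    targets = torch.randint(0, 2, (64,))    # binary labels
    loader = DataLoader(TensorDataset(features, targets), batch_size=8)
    demo_model = torch.nn.Linear(16, 2).to(demo_config.device)
    train(demo_config, demo_model, loader, loader, loader)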