"""
Author: Weisen Pan
Date: 2023-10-24
"""
import time
import torch
import torch.nn.functional as F
from datetime import timedelta
from sklearn import metrics
from tqdm import tqdm
from scheduler import WarmUpLR, downLR

def get_time_dif(start_time):
    """Get the time difference between now and the start time."""
    elapsed_time = time.time() - start_time
    return timedelta(seconds=int(round(elapsed_time)))
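
# Hedged sketch: WarmUpLR and downLR come from the local scheduler module,
# whose implementation is not shown here. Assuming WarmUpLR ramps the
# learning rate up linearly over `warmup_steps` and downLR then decays it
# linearly toward zero, the combined two-phase schedule could be expressed
# with PyTorch's built-in LambdaLR as below. This illustrates the assumed
# behaviour only; it is not a drop-in replacement for scheduler.py.
def linear_warmup_then_decay(optimizer, warmup_steps, total_steps):
    """Illustrative LambdaLR equivalent of the assumed warm-up/decay schedule."""
    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warm-up towards the base learning rate.
            return (step + 1) / warmup_steps
        # Linear decay from the base learning rate down to zero.
        return max(0.0, (total_steps - step) / max(1.0, total_steps - warmup_steps))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)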

def train(config, model, train_iter, dev_iter, test_iter):
    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    # Warm up over the first half of training, then decay over the second half.
    warmup_steps = config.num_epochs / 2 * len(train_iter)
    decay_steps = config.num_epochs * len(train_iter) - warmup_steps
    scheduler = downLR(optimizer, decay_steps)
    warmup_scheduler = WarmUpLR(optimizer, warmup_steps)
    
    dev_best_loss = float('inf')
    dev_best_acc = 0
    test_best_acc = 0

    for epoch in range(config.num_epochs):
        # evaluate() leaves the model in eval mode, so re-enable training
        # mode (dropout, batch norm) at the start of every epoch.
        model.train()
        epoch_loss = 0
        predictions, labels = [], []

        for trains, label_batch, poss, masks in tqdm(train_iter):
            # Move the batch (token ids, labels, positions, attention masks) to the target device.
            trains, label_batch, poss, masks = [t.to(config.device) for t in (trains, label_batch, poss, masks)]

            # Standard step: clear stale gradients, forward, compute loss, backward, update.
            model.zero_grad()
            outputs = model(trains, poss, masks)
            loss = F.cross_entropy(outputs, label_batch)
            loss.backward()
            optimizer.step()

            # Warm up for the first half of the epochs, then decay the learning rate.
            if epoch < warmup_steps / len(train_iter):
                warmup_scheduler.step()
            else:
                scheduler.step()

            # Accumulate loss and predictions for the epoch-level metrics.
            epoch_loss += loss.item()
            predictions.extend(torch.max(outputs, 1)[1].tolist())
            labels.extend(label_batch.tolist())

        train_acc = metrics.accuracy_score(labels, predictions)
        dev_acc, dev_loss = evaluate(config, model, dev_iter)

        # Track the best dev loss, and refresh the test accuracy whenever
        # dev accuracy reaches a new best.
        if dev_loss < dev_best_loss:
            dev_best_loss = dev_loss
        if dev_acc > dev_best_acc:
            dev_best_acc = dev_acc
            test_best_acc = evaluate(config, model, test_iter)[0]

        time_dif = get_time_dif(start_time)
        print(f'Epoch: {epoch + 1}/{config.num_epochs}, Train Loss: {epoch_loss / len(train_iter):.2f}, Train Acc: {train_acc:.2%}, Dev Loss: {dev_loss:.2f}, Dev Acc: {dev_acc:.2%}, Test Best Acc: {test_best_acc:.2%}, Time: {time_dif}')

    test(config, model, test_iter)

def test(config, model, test_iter):
    # Time the test pass itself; the original passed time.time() to
    # get_time_dif, which always reported an elapsed time of zero.
    start_time = time.time()
    model.eval()
    test_acc, test_loss, test_confusion = evaluate(config, model, test_iter, test=True)
    print(f'Test Loss: {test_loss:.2f}, Test Acc: {test_acc:.2%}')
    print("Confusion Matrix:")
    print(test_confusion)
    print("Time usage:", get_time_dif(start_time))

def evaluate(config, model, data_iter, test=False):
    """Return (accuracy, mean batch loss), plus the confusion matrix when test=True."""
    model.eval()
    total_loss = 0
    predictions, labels = [], []

    with torch.no_grad():
        for texts, labels_batch, poss, masks in data_iter:
            texts, poss, masks, labels_batch = [tensor.to(config.device) for tensor in [texts, poss, masks, labels_batch]]
            outputs = model(texts, poss, masks)
            loss = F.cross_entropy(outputs, labels_batch)
            
            total_loss += loss.item()
            predictions.extend(torch.max(outputs, 1)[1].tolist())
            labels.extend(labels_batch.tolist())

    accuracy = metrics.accuracy_score(labels, predictions)

    if test:
        confusion = metrics.confusion_matrix(labels, predictions)
        return accuracy, total_loss / len(data_iter), confusion
    return accuracy, total_loss / len(data_iter)
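
if __name__ == '__main__':
    # Minimal smoke-test sketch, assuming only what train() unpacks above:
    # iterators that yield (tokens, labels, positions, masks) batches and a
    # config exposing device/num_epochs/learning_rate. DummyModel and the
    # synthetic data below are illustrative placeholders, not part of the
    # original pipeline.
    from types import SimpleNamespace
    from torch.utils.data import DataLoader, TensorDataset

    class DummyModel(torch.nn.Module):
        def __init__(self, vocab_size=100, dim=16, num_classes=3):
            super().__init__()
            self.emb = torch.nn.Embedding(vocab_size, dim)
            self.fc = torch.nn.Linear(dim, num_classes)

        def forward(self, tokens, poss, masks):
            # Mean-pool the token embeddings under the attention mask.
            hidden = self.emb(tokens) * masks.unsqueeze(-1)
            pooled = hidden.sum(1) / masks.sum(1, keepdim=True).clamp(min=1)
            return self.fc(pooled)

    n, seq_len = 64, 20
    dataset = TensorDataset(
        torch.randint(0, 100, (n, seq_len)),   # token ids
        torch.randint(0, 3, (n,)),             # labels
        torch.arange(seq_len).repeat(n, 1),    # positions
        torch.ones(n, seq_len),                # attention masks
    )
    loader = DataLoader(dataset, batch_size=8)
    config = SimpleNamespace(
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        num_epochs=2,
        learning_rate=1e-3,
    )
    train(config, DummyModel().to(config.device), loader, loader, loader)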