"""
Author: Weisen Pan
Date: 2023-10-24
"""

import time
from datetime import timedelta

import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics
from tqdm import tqdm

from scheduler import WarmUpLR, downLR


def get_time_difference(start_time):
    """Return the wall-clock time elapsed since ``start_time``.

    Args:
        start_time: A timestamp previously obtained from ``time.time()``.

    Returns:
        A ``datetime.timedelta`` rounded to whole seconds.
    """
    elapsed_time = time.time() - start_time
    return timedelta(seconds=int(round(elapsed_time)))


def train(config, model, train_iter, dev_iter, test_iter):
    """Train ``model`` and evaluate it on the dev and test sets after every epoch.

    The learning rate is driven by ``WarmUpLR`` for the first
    ``config.num_epochs // 2`` epochs and by ``downLR`` for the remaining
    epochs. Both schedulers are stepped once per training batch. After the
    last epoch, ``test()`` is run on the model as it stands at that point.

    Args:
        config: Object providing ``learning_rate``, ``num_epochs`` and ``device``.
        model: The ``torch.nn.Module`` to train; its parameters are updated in place.
        train_iter: Iterable of ``(inputs, labels)`` training batches; must support ``len()``.
        dev_iter: Iterable of ``(inputs, labels)`` development batches.
        test_iter: Iterable of ``(inputs, labels)`` test batches.
    """
    start_time = time.time()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    warmup_epoch = config.num_epochs // 2
    iter_per_epoch = len(train_iter)
    scheduler = downLR(optimizer, (config.num_epochs - warmup_epoch) * iter_per_epoch)
    warmup_scheduler = WarmUpLR(optimizer, warmup_epoch * iter_per_epoch)

    dev_best_loss = float('inf')
    dev_best_acc = 0
    test_best_acc = 0
    total_batch = 0

    for epoch in range(config.num_epochs):
        # evaluate() puts the model into eval mode at the end of every epoch,
        # so training mode (dropout, batch-norm updates) must be restored here.
        model.train()
        loss_total = 0
        print(f'Epoch [{epoch + 1}/{config.num_epochs}]')
        predictions, true_values = [], []

        for trains, labels in tqdm(train_iter):
            # Cast inputs the same way evaluate() does so train and eval feed
            # the model identical dtypes.
            trains = trains.float().to(config.device)
            labels = labels.long().to(config.device)
            outputs = model(trains)
            loss = F.cross_entropy(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Per-batch LR schedule: warm up first, then decay.
            if epoch < warmup_epoch:
                warmup_scheduler.step()
            else:
                scheduler.step()

            total_batch += 1
            loss_total += loss.item()
            predictions.extend(outputs.detach().argmax(dim=1).tolist())
            true_values.extend(labels.tolist())

        train_acc = get_accuracy(true_values, predictions)
        dev_acc, dev_loss = evaluate(config, model, dev_iter)
        test_acc, test_loss = evaluate(config, model, test_iter)

        # '*' in the log line marks an epoch that improved the best dev loss.
        if dev_loss < dev_best_loss:
            dev_best_loss = dev_loss
            improvement_marker = '*'
        else:
            improvement_marker = ''

        # Report the test accuracy achieved at the epoch with the best dev accuracy.
        if dev_acc > dev_best_acc:
            dev_best_acc = dev_acc
            test_best_acc = test_acc

        elapsed_time = get_time_difference(start_time)
        print((
            f'Iter: {total_batch:6}, Train Loss: {loss_total/len(train_iter):.2f}, '
            f'Train Acc: {train_acc:.2%}, Dev Loss: {dev_loss:.2f}, Dev Acc: {dev_acc:.2%}, '
            f'Test Loss: {test_loss:.2f}, Test Acc: {test_acc:.2%}, Time: {elapsed_time} {improvement_marker}'
        ))

    print(f'Best Dev Acc: {dev_best_acc:.2%}, Best Test Acc: {test_best_acc:.2%}')
    # NOTE(review): no checkpoint is restored, so this evaluates the final-epoch
    # weights, not the weights from the best-dev epoch.
    test(config, model, test_iter)


def test(config, model, test_iter):
    """Evaluate ``model`` on the test set and print loss, accuracy and confusion matrix.

    Args:
        config: Object providing ``device``.
        model: The ``torch.nn.Module`` to evaluate.
        test_iter: Iterable of ``(inputs, labels)`` test batches; must support ``len()``.
    """
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_confusion = evaluate(config, model, test_iter, test=True)
    print(f'Test Loss: {test_loss:.2f}, Test Acc: {test_acc:.2%}')
    print(test_confusion)
    elapsed_time = get_time_difference(start_time)
    print(f"Time usage: {elapsed_time}")


def evaluate(config, model, data_iter, test=False):
    """Compute accuracy and mean per-batch cross-entropy loss on ``data_iter``.

    The model is switched to eval mode and left in eval mode on return;
    callers that continue training must call ``model.train()`` themselves.

    Args:
        config: Object providing ``device``.
        model: The ``torch.nn.Module`` to evaluate.
        data_iter: Iterable of ``(inputs, labels)`` batches; must support ``len()``.
        test: When true, also compute and return the confusion matrix.

    Returns:
        ``(accuracy, mean_loss)``, or ``(accuracy, mean_loss, confusion_matrix)``
        when ``test`` is true. ``mean_loss`` averages per-batch losses.
    """
    model.eval()
    loss_total = 0
    predictions, true_values = [], []
    with torch.no_grad():
        for texts, labels in data_iter:
            texts = texts.float().to(config.device)
            labels = labels.long().to(config.device)
            outputs = model(texts)
            loss_total += F.cross_entropy(outputs, labels).item()
            predictions.extend(outputs.argmax(dim=1).tolist())
            true_values.extend(labels.tolist())

    acc = get_accuracy(true_values, predictions)
    mean_loss = loss_total / len(data_iter)
    if test:
        confusion = metrics.confusion_matrix(true_values, predictions)
        return acc, mean_loss, confusion
    return acc, mean_loss


def get_accuracy(y_true, y_pred):
    """Return the fraction of positions where ``y_pred`` equals ``y_true``."""
    return metrics.accuracy_score(y_true, y_pred)