114 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			114 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
 | 
						|
Author: Weisen Pan
 | 
						|
Date: 2023-10-24
 | 
						|
"""
 | 
						|
import time
from datetime import timedelta

import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics

from scheduler import WarmUpLR, downLR
 | 
						|
 | 
						|
 | 
						|
def get_time_difference(start_time):
    """Return the wall-clock time elapsed since ``start_time``.

    Args:
        start_time: A timestamp in seconds, as produced by ``time.time()``.

    Returns:
        A ``datetime.timedelta`` holding the elapsed time rounded to the
        nearest whole second.
    """
    whole_seconds = int(round(time.time() - start_time))
    return timedelta(seconds=whole_seconds)
 | 
						|
 | 
						|
 | 
						|
def train(config, model, data):
    """Train ``model`` full-batch on ``data``, printing metrics every epoch.

    Epochs below ``config.num_epochs / 2`` are stepped with ``WarmUpLR``; the
    remaining epochs are stepped with ``downLR``. After every optimisation
    step the model is scored with ``evaluate`` (validation mask) and ``test``
    (test mask). The test accuracy observed at the best validation accuracy
    is tracked and printed. Once all epochs have run, ``test`` is called one
    more time with ``final=True``.

    Args:
        config: Settings object; this function reads ``learning_rate``,
            ``num_epochs`` and ``device`` from it.
        model: Module such that ``model(data)`` returns per-sample class
            scores that can be indexed with the masks stored on ``data``.
        data: Object exposing ``to(device)``, ``labels`` and ``train_mask``.
    """
    start_time = time.time()

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    # NOTE(review): true division is kept from the original, so this is a
    # float (2.5 for 5 epochs) -- confirm both schedulers accept that.
    warmup_epoch = config.num_epochs / 2
    scheduler = downLR(optimizer, (config.num_epochs - warmup_epoch))
    warmup_scheduler = WarmUpLR(optimizer, warmup_epoch)

    dev_best_loss, dev_best_acc, test_best_acc = float('inf'), 0.0, 0.0
    # Column 0 holds the epoch index, column 1 the learning rate reported
    # for that epoch.
    learning_rates = np.zeros((config.num_epochs, 2))

    # Moving the data does not depend on the epoch, so do it once.
    data = data.to(config.device)

    for epoch in range(config.num_epochs):
        print(f'Epoch [{epoch + 1}/{config.num_epochs}]')

        # evaluate() and test() switch the model to eval mode at the end of
        # every epoch, so training mode must be restored before each step.
        model.train()

        active_scheduler = warmup_scheduler if epoch < warmup_epoch else scheduler
        # NOTE(review): if the schedulers derive from torch's LR scheduler
        # base class, get_last_lr() is the documented accessor -- confirm.
        current_learning_rate = active_scheduler.get_lr()[0]
        learning_rates[epoch][0] = epoch
        learning_rates[epoch][1] = current_learning_rate
        print(f"Learning Rate: {current_learning_rate}")

        model.zero_grad()
        outputs = model(data)

        train_outputs = outputs[data.train_mask]
        train_labels = data.labels[data.train_mask]
        loss = F.cross_entropy(train_outputs, train_labels)
        loss.backward()

        optimizer.step()
        active_scheduler.step()

        predictions = torch.max(train_outputs, 1)[1]
        train_acc = get_accuracy(predictions, train_labels)

        dev_acc, dev_loss = evaluate(config, model, data)
        test_acc, test_loss = test(config, model, data)

        # '*' marks epochs where the validation loss reached a new minimum.
        if dev_loss < dev_best_loss:
            dev_best_loss = dev_loss
            improve_marker = '*'
        else:
            improve_marker = ''

        # Model selection: keep the test accuracy seen at the best
        # validation accuracy, not the best test accuracy overall.
        if dev_acc > dev_best_acc:
            dev_best_acc = dev_acc
            test_best_acc = test_acc

        elapsed_time = get_time_difference(start_time)
        status = (f'Iter: {epoch + 1:>6}, Train Loss: {loss.item():>5.2f}, Train Acc: {train_acc:>6.2%}, '
                  f'Val Loss: {dev_loss:>5.2f}, Val Acc: {dev_acc:>6.2%}, '
                  f'Test Loss: {test_loss:>5.2f}, Test Acc: {test_acc:>6.2%}, Time: {elapsed_time} {improve_marker}')
        print(status)
        print(f'Best Val Acc: {dev_best_acc}, Best Test Acc: {test_best_acc}')

    test(config, model, data, final=True)
 | 
						|
 | 
						|
 | 
						|
def test(config, model, data, final=False):
    """Score ``model`` on the samples selected by ``data.test_mask``.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        config: Unused here; kept so the signature matches ``evaluate``.
        model: Module such that ``model(data)`` returns per-sample class
            scores.
        data: Object exposing ``labels`` and ``test_mask``.
        final: When true, also print the metrics and a confusion matrix.

    Returns:
        ``(test_acc, test_loss)``, or ``(test_acc, test_loss, confusion)``
        when ``final`` is true. ``confusion`` has true labels on the rows
        and predicted labels on the columns.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(data)
        test_outputs = outputs[data.test_mask]
        test_labels = data.labels[data.test_mask]

        test_loss = F.cross_entropy(test_outputs, test_labels)
        predictions = torch.max(test_outputs, 1)[1]
        test_acc = get_accuracy(predictions, test_labels)

    if not final:
        return test_acc, test_loss

    print(f'Test Loss: {test_loss:>5.2f}, Test Acc: {test_acc:>6.2%}')
    # sklearn's signature is confusion_matrix(y_true, y_pred); passing the
    # predictions first would transpose the matrix.
    confusion = metrics.confusion_matrix(test_labels.cpu().numpy(), predictions.cpu().numpy())
    print('Confusion Matrix:\n', confusion)
    return test_acc, test_loss, confusion
 | 
						|
 | 
						|
 | 
						|
def evaluate(config, model, data):
    """Score ``model`` on the samples selected by ``data.val_mask``.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        config: Unused here; accepted for a uniform call signature.
        model: Module such that ``model(data)`` returns per-sample class
            scores.
        data: Object exposing ``labels`` and ``val_mask``.

    Returns:
        ``(eval_acc, eval_loss)`` for the validation samples.
    """
    model.eval()
    with torch.no_grad():
        val_outputs = model(data)[data.val_mask]
        val_labels = data.labels[data.val_mask]

        eval_loss = F.cross_entropy(val_outputs, val_labels)
        _, predicted = torch.max(val_outputs, 1)
        eval_acc = get_accuracy(predicted, val_labels)

    return eval_acc, eval_loss
 | 
						|
 | 
						|
 | 
						|
def get_accuracy(predictions, true_labels):
    """Return the fraction of ``predictions`` that match ``true_labels``.

    Both arguments are moved to the CPU and converted to NumPy arrays
    before being handed to ``sklearn.metrics.accuracy_score``.

    Args:
        predictions: Tensor of predicted class labels.
        true_labels: Tensor of reference class labels.

    Returns:
        The accuracy score computed by scikit-learn.
    """
    y_true = true_labels.cpu().numpy()
    y_pred = predictions.cpu().numpy()
    # accuracy_score expects (y_true, y_pred), in that order.
    return metrics.accuracy_score(y_true, y_pred)
 |