"""
Author: Weisen Pan
Date: 2023-10-24
"""
import argparse
import torch

from preprocess import preprocess_data_exp23_dag, preprocess_data_exp23
from select_model import select_model_exp2
from models.DAG_Transformer import DAGTransformer
from models.CNN import CNNModel
from models.LSTM import LSTMModel
from models.Vanilla_Transformer import VanillaTransformerModel
from train_model_dag import train
from train_model_vanilla import train as train_vanilla

# Values accepted on the command line; keep in sync with the dispatch in build_model().
VALID_MODELS = ['DAGTransformer', 'CNN', 'LSTM', 'VanillaTransformer']
VALID_SPLITS = ['Branch090505', 'Branch080101', 'Branch060202']


def parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model_name``, ``split`` and ``num_workers``.

    Raises:
        SystemExit: if ``--model_name`` is missing or a value is not one of the
            allowed choices (argparse prints a usage message first).
    """
    parser = argparse.ArgumentParser()
    # ``choices`` makes argparse reject invalid values with a usage error and
    # lists the allowed values in --help, instead of raising AssertionError.
    parser.add_argument('--model_name', required=True, choices=VALID_MODELS,
                        help='model architecture to train')
    parser.add_argument('--split', default='Branch060202', choices=VALID_SPLITS,
                        help='dataset split to train on')
    parser.add_argument('--num_workers', type=int, default=2,
                        help='number of DataLoader worker processes')
    return parser.parse_args(argv)


def build_data_loaders(config, model_name, split, num_workers=2):
    """Preprocess ``split`` and wrap the train/val/test sets in DataLoaders.

    The DAGTransformer uses ``preprocess_data_exp23_dag``; every other model
    uses ``preprocess_data_exp23``.

    Args:
        config: Model configuration; only ``config.batch_size`` is read here.
        model_name: One of ``VALID_MODELS``; selects the preprocessing routine.
        split: One of ``VALID_SPLITS``.
        num_workers: Worker processes per DataLoader.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    if model_name == 'DAGTransformer':
        train_data, val_data, test_data = preprocess_data_exp23_dag(split)
    else:
        train_data, val_data, test_data = preprocess_data_exp23(split)

    # NOTE(review): shuffle=False also applies to the training loader. This is
    # kept from the original; confirm it is intentional (e.g. ordered data).
    loader_args = {'batch_size': config.batch_size, 'num_workers': num_workers, 'shuffle': False}
    train_loader = torch.utils.data.DataLoader(dataset=train_data, **loader_args)
    val_loader = torch.utils.data.DataLoader(dataset=val_data, **loader_args)
    test_loader = torch.utils.data.DataLoader(dataset=test_data, **loader_args)
    return train_loader, val_loader, test_loader


def build_model(config, model_name):
    """Instantiate ``model_name`` on ``config.device`` and pick its trainer.

    Returns:
        (model, train_fn). DAGTransformer is paired with ``train`` from
        train_model_dag; all other models use ``train`` from train_model_vanilla.
    """
    registry = {
        'DAGTransformer': (DAGTransformer, train),
        'LSTM': (LSTMModel, train_vanilla),
        'CNN': (CNNModel, train_vanilla),
        'VanillaTransformer': (VanillaTransformerModel, train_vanilla),
    }
    model_class, train_fn = registry[model_name]
    return model_class(config).to(config.device), train_fn


def main(argv=None):
    """Parse arguments, build the data loaders and model, then run training."""
    opt = parse_args(argv)
    config = select_model_exp2(opt.model_name)
    train_loader, val_loader, test_loader = build_data_loaders(
        config, opt.model_name, opt.split, num_workers=opt.num_workers)
    model, train_fn = build_model(config, opt.model_name)
    train_fn(config, model, train_loader, val_loader, test_loader)


# All work stays behind this guard. With the "spawn" start method, DataLoader
# workers re-import the main script, and module-level preprocessing would
# otherwise run again in every worker.
if __name__ == '__main__':
    main()