"""
Author: Weisen Pan
Date: 2023-10-24
"""
import torch

class BaseConfig:
    """Base configuration: defaults live on the class, overrides on the instance.

    Every keyword argument passed to the constructor is stored as an
    instance attribute. It shadows a class-level default of the same name,
    or adds a new attribute when no such default exists. Names and values
    are not validated.
    """

    # Use the GPU when torch reports CUDA as available, otherwise the CPU.
    # Evaluated once, when the class body is executed.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_feat = 34  # presumably the input feature count -- confirm against the model code
    num_classes = 3
    batch_size = 500

    def __init__(self, **kwargs):
        """Store each keyword argument as an attribute on this instance."""
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

class DAGTransformerConfig(BaseConfig):
    """Default settings for the configuration named 'DAGTransformer'.

    All values are class-level defaults; pass keyword arguments to the
    constructor to override any of them on a single instance.
    The grouping below follows the apparent role of each name.
    """

    model_name = 'DAGTransformer'
    num_task = 7

    # Training-related values.
    num_epochs = 500
    learning_rate = 1e-4
    dropout = 0.3

    # Architecture-related values.
    num_encoder = 6
    num_head = 8
    hidden_dim = 1024
    d_k = 512
    res_num_layer = 4
    structure = True  # NOTE(review): meaning of this flag is not visible here -- confirm in the model code


class CNNConfig(BaseConfig):
    """Default settings for the configuration named 'CNN'.

    All values are class-level defaults; pass keyword arguments to the
    constructor to override any of them on a single instance.
    The grouping below follows the apparent role of each name.
    """

    model_name = 'CNN'
    num_task = 7

    # Training-related values.
    num_epochs = 3000
    learning_rate = 1e-3
    dropout = 0.3

    # Architecture-related values.
    outdim = 512
    pooldim = 3


class LSTMConfig(BaseConfig):
    """Default settings for the configuration named 'LSTM'.

    All values are class-level defaults; pass keyword arguments to the
    constructor to override any of them on a single instance.
    The grouping below follows the apparent role of each name.
    """

    model_name = 'LSTM'
    num_task = 7

    # Training-related values.
    num_epochs = 500
    learning_rate = 1e-3
    dropout = 0.5

    # Architecture-related values.
    num_layers = 6
    hidden = 1024
    pooldim = 3


class GCNConfig(BaseConfig):
    """Default settings for the configuration named 'GCN'.

    All values are class-level defaults; pass keyword arguments to the
    constructor to override any of them on a single instance.
    """

    model_name = 'GCN'

    # Training-related values.
    num_epochs = 15000
    learning_rate = 5e-3
    dropout = 0.5


class VanillaTransformerConfig(BaseConfig):
    """Default settings for the configuration named 'VanillaTransformer'.

    All values are class-level defaults; pass keyword arguments to the
    constructor to override any of them on a single instance.
    The grouping below follows the apparent role of each name.
    """

    model_name = 'VanillaTransformer'
    num_task = 7

    # Training-related values.
    num_epochs = 100
    learning_rate = 1e-4
    dropout = 0.5

    # Architecture-related values.
    num_encoder = 6
    num_head = 2
    hidden = 1024


def select_model_exp1():
    """Return a DAGTransformerConfig with all of its default values."""
    config = DAGTransformerConfig()
    return config


def select_model_exp2(model_name):
    """Return the experiment-2 configuration for ``model_name``.

    Args:
        model_name: One of 'DAGTransformer', 'CNN', 'LSTM' or
            'VanillaTransformer'.

    Returns:
        A new config instance with the experiment-specific overrides
        applied, or ``None`` when ``model_name`` is not a known key.
    """
    # Each entry pairs a config class with the keyword overrides it receives.
    recipes = {
        'DAGTransformer': (DAGTransformerConfig, {'num_epochs': 100}),
        'CNN': (CNNConfig, {'num_epochs': 100, 'learning_rate': 1e-4}),
        'LSTM': (LSTMConfig, {'learning_rate': 1e-4, 'num_epochs': 100}),
        'VanillaTransformer': (VanillaTransformerConfig, {}),
    }
    recipe = recipes.get(model_name)
    if recipe is None:
        return None
    config_cls, overrides = recipe
    return config_cls(**overrides)


def select_model_exp3(model_name):
    """Return the experiment-3 configuration for ``model_name``.

    Args:
        model_name: One of 'DAGTransformer', 'GCN', 'CNN', 'LSTM' or
            'VanillaTransformer'.

    Returns:
        A new config instance built with its class defaults, or ``None``
        when ``model_name`` is not a known key.
    """
    config_classes = {
        'DAGTransformer': DAGTransformerConfig,
        'GCN': GCNConfig,
        'CNN': CNNConfig,
        'LSTM': LSTMConfig,
        'VanillaTransformer': VanillaTransformerConfig,
    }
    config_cls = config_classes.get(model_name)
    if config_cls is None:
        return None
    return config_cls()