"""Model hyper-parameter configurations for the experiments.

Author: Weisen Pan
Date: 2023-10-24

Each ``*Config`` class keeps its defaults as class attributes; keyword
arguments passed to the constructor override them on that instance only.
The ``select_model_exp*`` helpers return a freshly built config for the
requested model under a given experiment setup.
"""
import functools
from typing import Callable, Dict, Optional

import torch


class BaseConfig:
    """Defaults shared by every model configuration.

    Any keyword argument given to the constructor is stored as an instance
    attribute, shadowing the class-level default of the same name. Keys are
    not validated: an unknown (or misspelled) key silently becomes a new
    attribute rather than raising.
    """

    # Evaluated once, at import time: CUDA when available, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_feat = 34  # presumably the input feature dimension — confirm with the data loader
    num_classes = 3
    batch_size = 500

    def __init__(self, **kwargs):
        """Apply per-instance overrides.

        Args:
            **kwargs: attribute name -> value; each pair is set on the
                instance, leaving the class defaults untouched.
        """
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        # Show only the per-instance overrides; class defaults are implied
        # by the class name.
        overrides = ', '.join(f'{key}={value!r}' for key, value in vars(self).items())
        return f'{type(self).__name__}({overrides})'


class DAGTransformerConfig(BaseConfig):
    """Default hyper-parameters for the DAGTransformer model."""

    model_name = 'DAGTransformer'
    dropout = 0.3
    num_epochs = 500
    num_task = 7
    learning_rate = 1e-4
    hidden_dim = 1024
    num_head = 8
    num_encoder = 6
    d_k = 512
    res_num_layer = 4
    structure = True


class CNNConfig(BaseConfig):
    """Default hyper-parameters for the CNN model."""

    model_name = 'CNN'
    num_task = 7
    outdim = 512
    num_epochs = 3000
    pooldim = 3
    dropout = 0.3
    learning_rate = 1e-3


class LSTMConfig(BaseConfig):
    """Default hyper-parameters for the LSTM model."""

    model_name = 'LSTM'
    num_task = 7
    learning_rate = 1e-3
    num_epochs = 500
    num_layers = 6
    hidden = 1024
    dropout = 0.5
    pooldim = 3


class GCNConfig(BaseConfig):
    """Default hyper-parameters for the GCN model.

    NOTE(review): unlike the other configs this one defines no ``num_task``;
    confirm whether that is intentional.
    """

    model_name = 'GCN'
    dropout = 0.5
    num_epochs = 15000
    learning_rate = 5e-3


class VanillaTransformerConfig(BaseConfig):
    """Default hyper-parameters for the VanillaTransformer model."""

    model_name = 'VanillaTransformer'
    dropout = 0.5
    num_epochs = 100
    num_task = 7
    learning_rate = 1e-4
    hidden = 1024
    num_head = 2
    num_encoder = 6


# Experiment 2: shorter training runs (num_epochs=100) and, for CNN/LSTM, a
# lower learning rate. GCN is not part of this experiment. Factories are
# stored instead of instances so only the requested config is built.
_EXP2_FACTORIES: Dict[str, Callable[[], BaseConfig]] = {
    'DAGTransformer': functools.partial(DAGTransformerConfig, num_epochs=100),
    'CNN': functools.partial(CNNConfig, num_epochs=100, learning_rate=1e-4),
    'LSTM': functools.partial(LSTMConfig, learning_rate=1e-4, num_epochs=100),
    'VanillaTransformer': VanillaTransformerConfig,
}

# Experiment 3: every model with its class defaults.
_EXP3_FACTORIES: Dict[str, Callable[[], BaseConfig]] = {
    'DAGTransformer': DAGTransformerConfig,
    'GCN': GCNConfig,
    'CNN': CNNConfig,
    'LSTM': LSTMConfig,
    'VanillaTransformer': VanillaTransformerConfig,
}


def _build(factories: Dict[str, Callable[[], BaseConfig]], model_name) -> Optional[BaseConfig]:
    """Build the config registered under ``model_name``, or return None."""
    factory = factories.get(model_name)
    return factory() if factory is not None else None


def select_model_exp1():
    """Return the experiment-1 config: a DAGTransformerConfig with defaults."""
    return DAGTransformerConfig()


def select_model_exp2(model_name) -> Optional[BaseConfig]:
    """Return a fresh experiment-2 config for ``model_name``.

    Args:
        model_name: one of 'DAGTransformer', 'CNN', 'LSTM',
            'VanillaTransformer'.

    Returns:
        A new config instance, or None if ``model_name`` is not part of
        experiment 2.
    """
    return _build(_EXP2_FACTORIES, model_name)


def select_model_exp3(model_name) -> Optional[BaseConfig]:
    """Return a fresh experiment-3 config (class defaults) for ``model_name``.

    Args:
        model_name: one of 'DAGTransformer', 'GCN', 'CNN', 'LSTM',
            'VanillaTransformer'.

    Returns:
        A new config instance, or None if ``model_name`` is unknown.
    """
    return _build(_EXP3_FACTORIES, model_name)