a877aed45f
Change-Id: I16cd7730c1e0732253ac52f51010f6b813295aa7
94 lines
1.9 KiB
Python
94 lines
1.9 KiB
Python
"""
Model hyperparameter configurations and per-experiment config selectors.

Author: Weisen Pan
Date: 2023-10-24
"""
|
|
import torch
|
|
|
|
class BaseConfig:
    """Defaults shared by every model configuration.

    Settings live as class attributes, so a subclass overrides one simply by
    redefining the name. Keyword arguments passed to the constructor are
    stored as instance attributes and shadow the class-level value of the
    same name for that instance only. Names are not validated: an unknown
    keyword just becomes a new attribute.
    """

    # Chosen once, when the class body executes at import time:
    # CUDA when torch reports it available, otherwise CPU.
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_feat: int = 34
    num_classes: int = 3
    batch_size: int = 500

    def __init__(self, **kwargs):
        # Apply each override onto this instance only; the class defaults
        # stay untouched for other instances.
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
|
|
|
|
|
|
class DAGTransformerConfig(BaseConfig):
    """Hyperparameters for the 'DAGTransformer' model.

    Extends the BaseConfig defaults (device, n_feat, num_classes,
    batch_size). Any attribute can be overridden per instance via
    constructor keywords, e.g. ``DAGTransformerConfig(num_epochs=100)``.
    """

    # Matches the 'DAGTransformer' key used by the select_model_exp* helpers.
    model_name: str = 'DAGTransformer'
    dropout: float = 0.3
    num_epochs: int = 500
    num_task: int = 7
    learning_rate: float = 1e-4
    hidden_dim: int = 1024
    num_head: int = 8
    num_encoder: int = 6
    d_k: int = 512
    res_num_layer: int = 4
    structure: bool = True
|
|
|
|
|
|
class CNNConfig(BaseConfig):
    """Hyperparameters for the 'CNN' model.

    Extends the BaseConfig defaults (device, n_feat, num_classes,
    batch_size). Per-instance overrides are passed as constructor keywords.
    """

    # Matches the 'CNN' key used by the select_model_exp* helpers.
    model_name: str = 'CNN'
    num_task: int = 7
    outdim: int = 512
    num_epochs: int = 3000
    pooldim: int = 3
    dropout: float = 0.3
    learning_rate: float = 1e-3
|
|
|
|
|
|
class LSTMConfig(BaseConfig):
    """Hyperparameters for the 'LSTM' model.

    Extends the BaseConfig defaults (device, n_feat, num_classes,
    batch_size). Per-instance overrides are passed as constructor keywords.
    """

    # Matches the 'LSTM' key used by the select_model_exp* helpers.
    model_name: str = 'LSTM'
    num_task: int = 7
    learning_rate: float = 1e-3
    num_epochs: int = 500
    num_layers: int = 6
    hidden: int = 1024
    dropout: float = 0.5
    pooldim: int = 3
|
|
|
|
|
|
class GCNConfig(BaseConfig):
    """Hyperparameters for the 'GCN' model.

    Extends the BaseConfig defaults (device, n_feat, num_classes,
    batch_size). Unlike the other model configs, no ``num_task`` is defined
    here.
    """

    # Matches the 'GCN' key used by select_model_exp3.
    model_name: str = 'GCN'
    dropout: float = 0.5
    num_epochs: int = 15000
    learning_rate: float = 5e-3
|
|
|
|
|
|
class VanillaTransformerConfig(BaseConfig):
    """Hyperparameters for the 'VanillaTransformer' model.

    Extends the BaseConfig defaults (device, n_feat, num_classes,
    batch_size). Per-instance overrides are passed as constructor keywords.
    """

    # Matches the 'VanillaTransformer' key used by the select_model_exp* helpers.
    model_name: str = 'VanillaTransformer'
    dropout: float = 0.5
    num_epochs: int = 100
    num_task: int = 7
    learning_rate: float = 1e-4
    hidden: int = 1024
    num_head: int = 2
    num_encoder: int = 6
|
|
|
|
|
|
def select_model_exp1():
    """Return the experiment-1 configuration.

    Experiment 1 uses only the DAGTransformer model with its class defaults.
    A new config instance is created on every call.
    """
    config = DAGTransformerConfig()
    return config
|
|
|
|
|
|
def select_model_exp2(model_name):
    """Return a fresh experiment-2 config for ``model_name``.

    Experiment 2 offers 'DAGTransformer', 'CNN', 'LSTM' and
    'VanillaTransformer' (no 'GCN'), with these overrides:

    * DAGTransformer: ``num_epochs=100``
    * CNN: ``num_epochs=100``, ``learning_rate=1e-4``
    * LSTM: ``learning_rate=1e-4``, ``num_epochs=100``
    * VanillaTransformer: class defaults

    Args:
        model_name: Name of the model whose config is wanted.

    Returns:
        A newly constructed config instance, or ``None`` when
        ``model_name`` is not one of the names above.
    """
    # Zero-argument factories, so only the requested config is constructed
    # instead of instantiating every config and discarding all but one.
    factories = {
        'DAGTransformer': lambda: DAGTransformerConfig(num_epochs=100),
        'CNN': lambda: CNNConfig(num_epochs=100, learning_rate=1e-4),
        'LSTM': lambda: LSTMConfig(learning_rate=1e-4, num_epochs=100),
        'VanillaTransformer': VanillaTransformerConfig,
    }
    factory = factories.get(model_name)
    if factory is None:
        return None
    return factory()
|
|
|
|
|
|
def select_model_exp3(model_name):
    """Return a fresh experiment-3 config for ``model_name``.

    Experiment 3 offers 'DAGTransformer', 'GCN', 'CNN', 'LSTM' and
    'VanillaTransformer', each with its class defaults (no overrides).

    Args:
        model_name: Name of the model whose config is wanted.

    Returns:
        A newly constructed config instance, or ``None`` when
        ``model_name`` is not one of the names above.
    """
    # Map names to classes and instantiate only the selected one, rather
    # than building every config on each call.
    config_classes = {
        'DAGTransformer': DAGTransformerConfig,
        'GCN': GCNConfig,
        'CNN': CNNConfig,
        'LSTM': LSTMConfig,
        'VanillaTransformer': VanillaTransformerConfig,
    }
    config_cls = config_classes.get(model_name)
    if config_cls is None:
        return None
    return config_cls()
|