# use-case-and-architecture/ai_computing_force_scheduling/select_model.py
# Weisen Pan a877aed45f AI-based CFN Traffic Control and Computer Force Scheduling
# Change-Id: I16cd7730c1e0732253ac52f51010f6b813295aa7
# 2023-11-03 00:09:19 -07:00
#
# 94 lines
# 1.9 KiB
# Python

"""
Author: Weisen Pan
Date: 2023-10-24
"""
import torch
class BaseConfig:
    """Shared defaults for every model configuration.

    Class attributes act as defaults. Each keyword argument passed to the
    constructor is stored as an instance attribute, shadowing the class-level
    value for that instance only. Keys are not validated, so a misspelled
    name creates a new attribute rather than overriding the intended one.
    """

    # Chosen once, at import time: CUDA when available, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): presumably the input feature width — confirm against the model code.
    n_feat = 34
    # NOTE(review): presumably the number of output classes — confirm.
    num_classes = 3
    batch_size = 500

    def __init__(self, **kwargs):
        """Override any default by passing it as a keyword argument."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        """Return ``ClassName(attr=value, ...)`` for every public, non-callable setting.

        Both class-level defaults and per-instance overrides are included, in
        alphabetical order (the order ``dir`` returns). This makes logged
        configs show their effective values.
        """
        settings = ', '.join(
            f'{name}={getattr(self, name)!r}'
            for name in dir(self)
            if not name.startswith('_') and not callable(getattr(self, name))
        )
        return f'{type(self).__name__}({settings})'
class DAGTransformerConfig(BaseConfig):
    """Default hyper-parameters for the 'DAGTransformer' model.

    Any value may be overridden per instance through ``BaseConfig``'s
    keyword-argument constructor.
    """

    model_name: str = 'DAGTransformer'

    # Remaining settings in alphabetical order.
    d_k: int = 512
    dropout: float = 0.3
    hidden_dim: int = 1024
    learning_rate: float = 1e-4
    num_encoder: int = 6
    num_epochs: int = 500
    num_head: int = 8
    num_task: int = 7
    res_num_layer: int = 4
    structure: bool = True
class CNNConfig(BaseConfig):
    """Default hyper-parameters for the 'CNN' model.

    Any value may be overridden per instance through ``BaseConfig``'s
    keyword-argument constructor.
    """

    model_name: str = 'CNN'

    # Remaining settings in alphabetical order.
    dropout: float = 0.3
    learning_rate: float = 1e-3
    num_epochs: int = 3000
    num_task: int = 7
    outdim: int = 512
    pooldim: int = 3
class LSTMConfig(BaseConfig):
    """Default hyper-parameters for the 'LSTM' model.

    Any value may be overridden per instance through ``BaseConfig``'s
    keyword-argument constructor.
    """

    model_name: str = 'LSTM'

    # Remaining settings in alphabetical order.
    dropout: float = 0.5
    hidden: int = 1024
    learning_rate: float = 1e-3
    num_epochs: int = 500
    num_layers: int = 6
    num_task: int = 7
    pooldim: int = 3
class GCNConfig(BaseConfig):
    """Default hyper-parameters for the 'GCN' model.

    Any value may be overridden per instance through ``BaseConfig``'s
    keyword-argument constructor.
    """

    model_name: str = 'GCN'

    # Remaining settings in alphabetical order.
    # NOTE(review): unlike the other model configs, no ``num_task`` is
    # defined here — confirm the GCN model code never reads it.
    dropout: float = 0.5
    learning_rate: float = 5e-3
    num_epochs: int = 15000
class VanillaTransformerConfig(BaseConfig):
    """Default hyper-parameters for the 'VanillaTransformer' model.

    Any value may be overridden per instance through ``BaseConfig``'s
    keyword-argument constructor.
    """

    model_name: str = 'VanillaTransformer'

    # Remaining settings in alphabetical order.
    dropout: float = 0.5
    hidden: int = 1024
    learning_rate: float = 1e-4
    num_encoder: int = 6
    num_epochs: int = 100
    num_head: int = 2
    num_task: int = 7
def select_model_exp1():
    """Build the configuration for experiment 1.

    This selector takes no model name; it always yields the DAGTransformer
    configuration with its class defaults unchanged.

    Returns:
        A new ``DAGTransformerConfig`` instance with no overrides.
    """
    config = DAGTransformerConfig()
    return config
def select_model_exp2(model_name):
    """Return the experiment-2 configuration for ``model_name``.

    Every model offered here trains for 100 epochs. The DAGTransformer, CNN
    and LSTM overrides set that explicitly; the VanillaTransformer class
    default is already 100. CNN and LSTM also have their learning rate
    lowered to 1e-4. GCN is not offered by this selector.

    Args:
        model_name: One of 'DAGTransformer', 'CNN', 'LSTM' or
            'VanillaTransformer'.

    Returns:
        A freshly constructed config instance, or ``None`` for any other name.
    """
    # Map names to zero-argument factories so that only the requested config
    # is instantiated, instead of building every config on each call.
    factories = {
        'DAGTransformer': lambda: DAGTransformerConfig(num_epochs=100),
        'CNN': lambda: CNNConfig(num_epochs=100, learning_rate=1e-4),
        'LSTM': lambda: LSTMConfig(learning_rate=1e-4, num_epochs=100),
        'VanillaTransformer': VanillaTransformerConfig,
    }
    factory = factories.get(model_name)
    return None if factory is None else factory()
def select_model_exp3(model_name):
    """Return the experiment-3 configuration for ``model_name``.

    Experiment 3 uses each model's class defaults with no overrides and is
    the only selector here that offers GCN.

    Args:
        model_name: One of 'DAGTransformer', 'GCN', 'CNN', 'LSTM' or
            'VanillaTransformer'.

    Returns:
        A freshly constructed config instance, or ``None`` for any other name.
    """
    # Store the classes themselves (not instances) so that only the requested
    # config is constructed per call.
    config_classes = {
        'DAGTransformer': DAGTransformerConfig,
        'GCN': GCNConfig,
        'CNN': CNNConfig,
        'LSTM': LSTMConfig,
        'VanillaTransformer': VanillaTransformerConfig,
    }
    config_class = config_classes.get(model_name)
    return None if config_class is None else config_class()