"""
Author: Weisen Pan
Date: 2023-10-24
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMModel(nn.Module):
    """Sequence model: LSTM encoder, max-pool over hidden units, linear head.

    ``config`` must expose the attributes ``n_feat``, ``hidden``, ``dropout``,
    ``num_layers``, ``pooldim``, ``num_task`` and ``num_classes``.
    """

    def __init__(self, config):
        super().__init__()
        # Built without ``batch_first``, so this LSTM expects time on dim 0.
        self.lstm = nn.LSTM(
            input_size=config.n_feat,
            hidden_size=config.hidden,
            num_layers=config.num_layers,
            dropout=config.dropout,
        )
        # Kernel size only: stride defaults to the kernel size and there is no
        # padding, so each hidden vector shrinks to hidden // pooldim values.
        self.maxpool = nn.MaxPool1d(config.pooldim)
        pooled_width = config.hidden // config.pooldim
        # The head receives one pooled vector per timestep, flattened together.
        # The sequence length fed to forward() must therefore equal num_task,
        # or the Linear layer's input size will not match.
        head_in = pooled_width * config.num_task
        self.fc = nn.Linear(head_in, config.num_classes)

    def forward(self, x):
        """Run the LSTM, pool over hidden units and apply the linear head.

        Args:
            x: 3-D tensor. Assumed to be (batch, seq_len, n_feat); TODO confirm
                with callers. Dims 0 and 1 are swapped before the time-major
                LSTM and swapped back afterwards.

        Returns:
            The raw output of ``self.fc`` (no activation applied), one row per
            element of dim 0 of ``x``.
        """
        time_major = x.permute(1, 0, 2)
        encoded, _ = self.lstm(time_major)  # final (h, c) states are unused
        batch_major = encoded.permute(1, 0, 2)
        # MaxPool1d pools along the last dim (hidden units). Dim 1 (time) is
        # treated as the channel axis, so timesteps are pooled independently.
        pooled = self.maxpool(batch_major)
        flattened = pooled.reshape(pooled.size(0), -1)
        return self.fc(flattened)