xref: /aosp_15_r20/external/pytorch/test/onnx/model_defs/lstm_flattening_result.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from torch import nn
2from torch.nn.utils.rnn import PackedSequence
3
4
class LstmFlatteningResult(nn.LSTM):
    """LSTM variant that flattens the (hidden, cell) state tuple in its output.

    ``nn.LSTM.forward`` returns ``(output, (h_n, c_n))``; ONNX export prefers a
    flat tuple of tensors, so this wrapper returns ``(output, h_n, c_n)``.
    """

    def forward(self, input, *fargs, **fkwargs):
        # Use super() rather than an explicit unbound nn.LSTM.forward call;
        # behavior is identical but this is the idiomatic cooperative form.
        output, (hidden, cell) = super().forward(input, *fargs, **fkwargs)
        return output, hidden, cell
9
10
class LstmFlatteningResultWithSeqLength(nn.Module):
    """Wraps an ``nn.LSTM`` fed with a ``PackedSequence`` and flattens its output.

    The inner LSTM returns ``(output, (h_n, c_n))``; this module returns the
    flat tuple ``(output, h_n, c_n)``, which is friendlier to ONNX export.

    Args:
        input_size: Number of expected features in the input.
        hidden_size: Number of features in the hidden state.
        layers: Number of stacked LSTM layers.
        bidirect: Whether the LSTM is bidirectional.
        dropout: Dropout probability between stacked layers.
        batch_first: Whether input/output tensors are batch-first.
    """

    def __init__(self, input_size, hidden_size, layers, bidirect, dropout, batch_first):
        super().__init__()

        self.batch_first = batch_first
        self.inner_model = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=layers,
            bidirectional=bidirect,
            dropout=dropout,
            batch_first=batch_first,
        )

    def forward(self, input: PackedSequence, hx=None):
        # Invoke the module (__call__) instead of .forward() directly so that
        # hooks and tracing machinery registered on the submodule still run.
        output, (hidden, cell) = self.inner_model(input, hx)
        return output, hidden, cell
28
29
class LstmFlatteningResultWithoutSeqLength(nn.Module):
    """Wraps an ``nn.LSTM`` fed with a plain tensor and flattens its output.

    The inner LSTM returns ``(output, (h_n, c_n))``; this module returns the
    flat tuple ``(output, h_n, c_n)``, which is friendlier to ONNX export.

    Args:
        input_size: Number of expected features in the input.
        hidden_size: Number of features in the hidden state.
        layers: Number of stacked LSTM layers.
        bidirect: Whether the LSTM is bidirectional.
        dropout: Dropout probability between stacked layers.
        batch_first: Whether input/output tensors are batch-first.
    """

    def __init__(self, input_size, hidden_size, layers, bidirect, dropout, batch_first):
        super().__init__()

        self.batch_first = batch_first
        self.inner_model = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=layers,
            bidirectional=bidirect,
            dropout=dropout,
            batch_first=batch_first,
        )

    def forward(self, input, hx=None):
        # Invoke the module (__call__) instead of .forward() directly so that
        # hooks and tracing machinery registered on the submodule still run.
        output, (hidden, cell) = self.inner_model(input, hx)
        return output, hidden, cell
47