"""LSTM modules whose forward returns ``(output, hidden, cell)`` as a flat tuple.

``nn.LSTM`` returns ``(output, (h_n, c_n))``; every module here unpacks the
nested state tuple so callers receive three top-level values instead.
"""

from torch import nn
from torch.nn.utils.rnn import PackedSequence


def _make_inner_lstm(input_size, hidden_size, layers, bidirect, dropout, batch_first):
    """Build the ``nn.LSTM`` held as ``inner_model`` by the wrapper modules below.

    Translates the wrappers' short argument names (``layers``, ``bidirect``)
    into the matching ``nn.LSTM`` keyword arguments.
    """
    return nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=layers,
        bidirectional=bidirect,
        dropout=dropout,
        batch_first=batch_first,
    )


class LstmFlatteningResult(nn.LSTM):
    """``nn.LSTM`` subclass whose forward returns a flat 3-tuple.

    Takes exactly the arguments of ``nn.LSTM.forward`` and returns
    ``(output, hidden, cell)`` rather than ``(output, (hidden, cell))``.
    """

    def forward(self, input, *fargs, **fkwargs):
        # Calling through the class (not super()) always resolves to
        # nn.LSTM's own implementation, independent of the subclass MRO.
        lstm_output, lstm_state = nn.LSTM.forward(self, input, *fargs, **fkwargs)
        hidden, cell = lstm_state
        return lstm_output, hidden, cell


class LstmFlatteningResultWithSeqLength(nn.Module):
    """Wrap an ``nn.LSTM`` (``self.inner_model``) and flatten its result.

    ``forward`` declares its input as a ``PackedSequence``.

    Args:
        input_size: ``nn.LSTM`` ``input_size``.
        hidden_size: ``nn.LSTM`` ``hidden_size``.
        layers: ``nn.LSTM`` ``num_layers``.
        bidirect: ``nn.LSTM`` ``bidirectional``.
        dropout: ``nn.LSTM`` ``dropout``.
        batch_first: ``nn.LSTM`` ``batch_first``; also stored as ``self.batch_first``.
    """

    def __init__(self, input_size, hidden_size, layers, bidirect, dropout, batch_first):
        super().__init__()
        self.batch_first = batch_first
        self.inner_model = _make_inner_lstm(
            input_size, hidden_size, layers, bidirect, dropout, batch_first
        )

    def forward(self, input: PackedSequence, hx=None):
        """Run ``inner_model`` and return ``(output, h_n, c_n)``.

        For a packed input, ``nn.LSTM`` returns ``output`` as a
        ``PackedSequence`` too. The ``input`` annotation is kept because
        TorchScript types arguments from their annotations.
        """
        # NOTE(review): ``.forward`` is called directly, which skips the
        # nn.Module hooks that ``self.inner_model(...)`` would run; preserved
        # as-is to keep existing behavior.
        packed_out, (h_n, c_n) = self.inner_model.forward(input, hx)
        return packed_out, h_n, c_n


class LstmFlatteningResultWithoutSeqLength(nn.Module):
    """Wrap an ``nn.LSTM`` (``self.inner_model``) and flatten its result.

    Same as the ``WithSeqLength`` variant except that ``forward``'s input is
    unannotated (TorchScript would therefore assume a ``Tensor``).

    Args:
        input_size: ``nn.LSTM`` ``input_size``.
        hidden_size: ``nn.LSTM`` ``hidden_size``.
        layers: ``nn.LSTM`` ``num_layers``.
        bidirect: ``nn.LSTM`` ``bidirectional``.
        dropout: ``nn.LSTM`` ``dropout``.
        batch_first: ``nn.LSTM`` ``batch_first``; also stored as ``self.batch_first``.
    """

    def __init__(self, input_size, hidden_size, layers, bidirect, dropout, batch_first):
        super().__init__()
        self.batch_first = batch_first
        self.inner_model = _make_inner_lstm(
            input_size, hidden_size, layers, bidirect, dropout, batch_first
        )

    def forward(self, input, hx=None):
        """Run ``inner_model`` and return ``(output, h_n, c_n)``."""
        # NOTE(review): direct ``.forward`` call bypasses nn.Module hooks;
        # preserved as-is to keep existing behavior.
        seq_out, (h_n, c_n) = self.inner_model.forward(input, hx)
        return seq_out, h_n, c_n