import torch
from torch import nn
import torch.nn.functional as F  # NOTE(review): unused here; kept in case other file parts rely on it


class LossGen(nn.Module):
    """Stacked two-layer GRU regressor over a (loss, perc) feature pair.

    Each time step's 2-dim input ``[loss, perc]`` is projected to 8 features
    (tanh), passed through two GRU layers, and reduced to a single scalar
    output per step. Hidden states can be threaded between calls for
    streaming / stateful inference.
    """

    def __init__(self, gru1_size: int = 16, gru2_size: int = 16) -> None:
        """Build the network.

        Args:
            gru1_size: hidden size of the first GRU layer.
            gru2_size: hidden size of the second GRU layer.
        """
        super().__init__()

        self.gru1_size = gru1_size
        self.gru2_size = gru2_size
        self.dense_in = nn.Linear(2, 8)
        self.gru1 = nn.GRU(8, self.gru1_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru2_size, 1)

    def forward(self, loss, perc, states=None):
        """Run one (possibly stateful) forward pass.

        Args:
            loss: tensor of shape ``(batch, seq, 1)`` — assumed feature-last;
                TODO confirm against callers.
            perc: tensor broadcast-compatible with ``loss`` along the last dim.
            states: optional ``[gru1_state, gru2_state]`` from a previous call;
                when ``None``, both hidden states start at zero.

        Returns:
            Tuple of the per-step output tensor ``(batch, seq, 1)`` and the
            updated ``[gru1_state, gru2_state]`` list for the next call.
        """
        device = loss.device
        batch_size = loss.size(0)
        if states is None:
            # Fresh zero hidden states on the same device as the input.
            gru1_state = torch.zeros((1, batch_size, self.gru1_size), device=device)
            gru2_state = torch.zeros((1, batch_size, self.gru2_size), device=device)
        else:
            gru1_state = states[0]
            gru2_state = states[1]
        # Concatenate the two scalar features and lift them to 8 dims.
        x = torch.tanh(self.dense_in(torch.cat([loss, perc], dim=-1)))
        gru1_out, gru1_state = self.gru1(x, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        return self.dense_out(gru2_out), [gru1_state, gru2_state]