import numbers
import warnings
from collections import namedtuple
from typing import List, Tuple

import torch
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter


"""
Some helper classes for writing custom TorchScript LSTMs.

Goals:
- Classes are easy to read, use, and extend
- Performance of custom LSTMs approaches fused-kernel levels of speed.

A few notes about features we could add to clean up the code below:
- Support enumerate with nn.ModuleList:
  https://github.com/pytorch/pytorch/issues/14471
- Support enumerate/zip with lists:
  https://github.com/pytorch/pytorch/issues/15952
- Support overriding of class methods:
  https://github.com/pytorch/pytorch/issues/10733
- Support passing around user-defined namedtuple types for readability
- Support slicing w/ range. It enables reversing lists easily.
  https://github.com/pytorch/pytorch/issues/10774
- Multiline type annotations. List[List[Tuple[Tensor, Tensor]]] is verbose.
  https://github.com/pytorch/pytorch/pull/14922
"""


def script_lstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM."""

    # The following are not implemented.
    assert bias
    assert not batch_first

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    elif dropout:
        stack_type = StackedLSTMWithDropout
        layer_type = LSTMLayer
        dirs = 1
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[LSTMCell, input_size, hidden_size],
        other_layer_args=[LSTMCell, hidden_size * dirs, hidden_size],
    )


def script_lnlstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
    decompose_layernorm=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM, with layer
    normalization applied to the gates and cell state."""

    # The following are not implemented.
    assert bias
    assert not batch_first
    assert not dropout

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[
            LayerNormLSTMCell,
            input_size,
            hidden_size,
            decompose_layernorm,
        ],
        other_layer_args=[
            LayerNormLSTMCell,
            hidden_size * dirs,
            hidden_size,
            decompose_layernorm,
        ],
    )


LSTMState = namedtuple("LSTMState", ["hx", "cx"])


def reverse(lst: List[Tensor]) -> List[Tensor]:
    return lst[::-1]


class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        gates = (
            torch.mm(input, self.weight_ih.t())
            + self.bias_ih
            + torch.mm(hx, self.weight_hh.t())
            + self.bias_hh
        )
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LayerNorm(jit.ScriptModule):
    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # XXX: This is true for our LSTM / NLP use case and helps simplify code
        assert len(normalized_shape) == 1

        self.weight = Parameter(torch.ones(normalized_shape))
        self.bias = Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    @jit.script_method
    def compute_layernorm_stats(self, input):
        mu = input.mean(-1, keepdim=True)
        sigma = input.std(-1, keepdim=True, unbiased=False)
        return mu, sigma

    @jit.script_method
    def forward(self, input):
        mu, sigma = self.compute_layernorm_stats(input)
        return (input - mu) / sigma * self.weight + self.bias


class LayerNormLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, decompose_layernorm=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        # The layernorms provide learnable biases

        if decompose_layernorm:
            ln = LayerNorm
        else:
            ln = nn.LayerNorm

        self.layernorm_i = ln(4 * hidden_size)
        self.layernorm_h = ln(4 * hidden_size)
        self.layernorm_c = ln(hidden_size)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = igates + hgates
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = input.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
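        # Stack the per-timestep outputs back into a (seq_len, batch, hidden) tensor.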
        return torch.stack(outputs), state


class ReverseLSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(reverse(outputs)), state


class BidirLSTMLayer(jit.ScriptModule):
    __constants__ = ["directions"]

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.directions = nn.ModuleList(
            [
                LSTMLayer(cell, *cell_args),
                ReverseLSTMLayer(cell, *cell_args),
            ]
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: [forward LSTMState, backward LSTMState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)
            outputs += [out]
            output_states += [out_state]
            i += 1
        return torch.cat(outputs, -1), output_states
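

# A minimal usage sketch for BidirLSTMLayer (hypothetical sizes; the function
# below is illustrative and not part of the test suite): the forward and
# reverse outputs are concatenated along the feature dimension, so the output
# feature size is 2 * hidden_size.
def example_bidir_layer(seq_len=5, batch=2, input_size=3, hidden_size=7):
    layer = BidirLSTMLayer(LSTMCell, input_size, hidden_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(2)  # one initial state per direction
    ]
    out, out_states = layer(torch.randn(seq_len, batch, input_size), states)
    assert out.shape == (seq_len, batch, 2 * hidden_size)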


def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
    layers = [layer(*first_layer_args)] + [
        layer(*other_layer_args) for _ in range(num_layers - 1)
    ]
    return nn.ModuleList(layers)


class StackedLSTM(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states
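

# A minimal shape sketch (hypothetical sizes; illustrative only, not part of
# the original tests): StackedLSTM threads one (hx, cx) state per layer and
# returns the last layer's output sequence plus the updated per-layer states.
def example_stacked_lstm(
    seq_len=5, batch=2, input_size=3, hidden_size=7, num_layers=4
):
    rnn = script_lstm(input_size, hidden_size, num_layers)
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    out, out_states = rnn(inp, states)
    assert out.shape == (seq_len, batch, hidden_size)
    assert len(out_states) == num_layers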


# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor, Tensor]]]. It would be nice to subclass StackedLSTM,
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733
class StackedLSTM2(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]
    ) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
        # List[List[LSTMState]]: The outer list is for layers,
        # the inner list is for directions.
        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states
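

# Note: StackedLSTM2 indexes its states as states[layer][direction]; e.g. a
# two-layer bidirectional net takes [[fwd_0, rev_0], [fwd_1, rev_1]].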


class StackedLSTMWithDropout(jit.ScriptModule):
    # Necessary for iterating through self.layers and dropout support
    __constants__ = ["layers", "num_layers"]

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )
        # Introduces a Dropout layer on the outputs of each LSTM layer except
        # the last layer, with dropout probability = 0.4.
        self.num_layers = num_layers

        if num_layers == 1:
            warnings.warn(
                "dropout lstm adds dropout layers after all but the last "
                "recurrent layer, so it expects num_layers greater than "
                "1, but got num_layers = 1"
            )

        self.dropout_layer = nn.Dropout(0.4)

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            # Apply dropout to the output of every layer except the last
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)
            output_states += [out_state]
            i += 1
        return output, output_states


def flatten_states(states):
    states = list(zip(*states))
    assert len(states) == 2
    return [torch.stack(state) for state in states]


def double_flatten_states(states):
    # XXX: Can probably write this in a nicer way
    states = flatten_states([flatten_states(inner) for inner in states])
    return [hidden.view([-1] + list(hidden.shape[2:])) for hidden in states]
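

# A shape sketch (hypothetical sizes; illustrative only): flatten_states turns
# a list of per-layer LSTMState tuples into the stacked (h_0, c_0) pair that
# nn.LSTM expects, each of shape (num_layers, batch, hidden_size).
def example_flatten_states(batch=2, hidden_size=7, num_layers=4):
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    h, c = flatten_states(states)
    assert h.shape == (num_layers, batch, hidden_size)
    assert c.shape == (num_layers, batch, hidden_size)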


def test_script_rnn_layer(seq_len, batch, input_size, hidden_size):
    inp = torch.randn(seq_len, batch, input_size)
    state = LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
    rnn = LSTMLayer(LSTMCell, input_size, hidden_size)
    out, out_state = rnn(inp, state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, 1)
    lstm_state = LSTMState(state.hx.unsqueeze(0), state.cx.unsqueeze(0))
    for lstm_param, custom_param in zip(lstm.all_weights[0], rnn.parameters()):
        assert lstm_param.shape == custom_param.shape
        with torch.no_grad():
            lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (out_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (out_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers)
    out, out_state = rnn(inp, states)
    custom_state = flatten_states(out_state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers)
    lstm_state = flatten_states(states)
    for layer in range(num_layers):
        custom_params = list(rnn.parameters())[4 * layer : 4 * (layer + 1)]
        for lstm_param, custom_param in zip(lstm.all_weights[layer], custom_params):
            assert lstm_param.shape == custom_param.shape
            with torch.no_grad():
                lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5
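

# Note: the parameter copies in these tests rely on each custom cell exposing
# its parameters in the same order as nn.LSTM's all_weights: weight_ih,
# weight_hh, bias_ih, bias_hh (four entries per layer and direction).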


def test_script_stacked_bidir_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        [
            LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
            for _ in range(2)
        ]
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True)
    out, out_state = rnn(inp, states)
    custom_state = double_flatten_states(out_state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
    lstm_state = double_flatten_states(states)
    for layer in range(num_layers):
        for direct in range(2):
            index = 2 * layer + direct
            custom_params = list(rnn.parameters())[4 * index : 4 * index + 4]
            for lstm_param, custom_param in zip(lstm.all_weights[index], custom_params):
                assert lstm_param.shape == custom_param.shape
                with torch.no_grad():
                    lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_lstm_dropout(
    seq_len, batch, input_size, hidden_size, num_layers
):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, dropout=True)

    # just a smoke test
    out, out_state = rnn(inp, states)
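

# Dropout makes the output stochastic in training mode, so there is no exact
# native control to compare against; the dropout test above is a smoke test only.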


def test_script_stacked_lnlstm(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lnlstm(input_size, hidden_size, num_layers)

    # just a smoke test
    out, out_state = rnn(inp, states)


test_script_rnn_layer(5, 2, 3, 7)
test_script_stacked_rnn(5, 2, 3, 7, 4)
test_script_stacked_bidir_rnn(5, 2, 3, 7, 4)
test_script_stacked_lstm_dropout(5, 2, 3, 7, 4)
test_script_stacked_lnlstm(5, 2, 3, 7, 4)
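

# An extra smoke test (a minimal sketch, not part of the original suite):
# exercise the decomposed LayerNorm path as well.
def example_decomposed_lnlstm(
    seq_len=5, batch=2, input_size=3, hidden_size=7, num_layers=4
):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lnlstm(
        input_size, hidden_size, num_layers, decompose_layernorm=True
    )
    out, out_state = rnn(inp, states)


example_decomposed_lnlstm()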