import numbers
import warnings
from collections import namedtuple
from typing import List, Tuple

import torch
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter


"""
Some helper classes for writing custom TorchScript LSTMs.

Goals:
- Classes are easy to read, use, and extend
- Performance of custom LSTMs approaches fused-kernel levels of speed.

A few notes about features we could add to clean up the code below:
- Support enumerate with nn.ModuleList:
  https://github.com/pytorch/pytorch/issues/14471
- Support enumerate/zip with lists:
  https://github.com/pytorch/pytorch/issues/15952
- Support overriding of class methods:
  https://github.com/pytorch/pytorch/issues/10733
- Support passing around user-defined namedtuple types for readability
- Support slicing w/ range. It enables reversing lists easily.
  https://github.com/pytorch/pytorch/issues/10774
- Multiline type annotations. List[List[Tuple[Tensor,Tensor]]] is verbose
  https://github.com/pytorch/pytorch/pull/14922
"""


def script_lstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM."""

    # The following are not implemented.
    assert bias
    assert not batch_first

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    elif dropout:
        stack_type = StackedLSTMWithDropout
        layer_type = LSTMLayer
        dirs = 1
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[LSTMCell, input_size, hidden_size],
        other_layer_args=[LSTMCell, hidden_size * dirs, hidden_size],
    )


def script_lnlstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
    decompose_layernorm=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM."""

    # The following are not implemented.
    assert bias
    assert not batch_first
    assert not dropout

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[
            LayerNormLSTMCell,
            input_size,
            hidden_size,
            decompose_layernorm,
        ],
        other_layer_args=[
            LayerNormLSTMCell,
            hidden_size * dirs,
            hidden_size,
            decompose_layernorm,
        ],
    )


LSTMState = namedtuple("LSTMState", ["hx", "cx"])


def reverse(lst: List[Tensor]) -> List[Tensor]:
    return lst[::-1]


class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        gates = (
            torch.mm(input, self.weight_ih.t())
            + self.bias_ih
            + torch.mm(hx, self.weight_hh.t())
            + self.bias_hh
        )
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LayerNorm(jit.ScriptModule):
    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # XXX: This is true for our LSTM / NLP use case and helps simplify code
        assert len(normalized_shape) == 1

        self.weight = Parameter(torch.ones(normalized_shape))
        self.bias = Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    @jit.script_method
    def compute_layernorm_stats(self, input):
        mu = input.mean(-1, keepdim=True)
        sigma = input.std(-1, keepdim=True, unbiased=False)
        return mu, sigma

    @jit.script_method
    def forward(self, input):
        mu, sigma = self.compute_layernorm_stats(input)
        return (input - mu) / sigma * self.weight + self.bias


class LayerNormLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, decompose_layernorm=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        # The layernorms provide learnable biases

        if decompose_layernorm:
            ln = LayerNorm
        else:
            ln = nn.LayerNorm

        self.layernorm_i = ln(4 * hidden_size)
        self.layernorm_h = ln(4 * hidden_size)
        self.layernorm_c = ln(hidden_size)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = igates + hgates
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = input.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state


class ReverseLSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(reverse(outputs)), state


class BidirLSTMLayer(jit.ScriptModule):
    __constants__ = ["directions"]

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.directions = nn.ModuleList(
            [
                LSTMLayer(cell, *cell_args),
                ReverseLSTMLayer(cell, *cell_args),
            ]
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: [forward LSTMState, backward LSTMState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)
            outputs += [out]
            output_states += [out_state]
            i += 1
        return torch.cat(outputs, -1), output_states


def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
    layers = [layer(*first_layer_args)] + [
        layer(*other_layer_args) for _ in range(num_layers - 1)
    ]
    return nn.ModuleList(layers)


class StackedLSTM(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733
class StackedLSTM2(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]
    ) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
        # List[List[LSTMState]]: The outer list is for layers,
        # inner list is for directions.
        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


class StackedLSTMWithDropout(jit.ScriptModule):
    # Necessary for iterating through self.layers and dropout support
    __constants__ = ["layers", "num_layers"]

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )
        # Introduces a Dropout layer on the outputs of each LSTM layer except
        # the last layer, with dropout probability = 0.4.
        self.num_layers = num_layers

        if num_layers == 1:
            warnings.warn(
                "dropout lstm adds dropout layers after all but the last "
                "recurrent layer; it expects num_layers greater than 1, "
                "but got num_layers = 1"
            )

        self.dropout_layer = nn.Dropout(0.4)

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            # Apply dropout to the outputs of all but the last layer
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)
            output_states += [out_state]
            i += 1
        return output, output_states


def flatten_states(states):
    states = list(zip(*states))
    assert len(states) == 2
    return [torch.stack(state) for state in states]


def double_flatten_states(states):
    # XXX: Can probably write this in a nicer way
    states = flatten_states([flatten_states(inner) for inner in states])
    return [hidden.view([-1] + list(hidden.shape[2:])) for hidden in states]


def test_script_rnn_layer(seq_len, batch, input_size, hidden_size):
    inp = torch.randn(seq_len, batch, input_size)
    state = LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
    rnn = LSTMLayer(LSTMCell, input_size, hidden_size)
    out, out_state = rnn(inp, state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, 1)
    lstm_state = LSTMState(state.hx.unsqueeze(0), state.cx.unsqueeze(0))
    for lstm_param, custom_param in zip(lstm.all_weights[0], rnn.parameters()):
        assert lstm_param.shape == custom_param.shape
        with torch.no_grad():
            lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (out_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (out_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers)
    out, out_state = rnn(inp, states)
    custom_state = flatten_states(out_state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers)
    lstm_state = flatten_states(states)
    for layer in range(num_layers):
        custom_params = list(rnn.parameters())[4 * layer : 4 * (layer + 1)]
        for lstm_param, custom_param in zip(lstm.all_weights[layer], custom_params):
            assert lstm_param.shape == custom_param.shape
            with torch.no_grad():
                lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_bidir_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        [
            LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
            for _ in range(2)
        ]
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True)
    out, out_state = rnn(inp, states)
    custom_state = double_flatten_states(out_state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
    lstm_state = double_flatten_states(states)
    for layer in range(num_layers):
        for direct in range(2):
            index = 2 * layer + direct
            custom_params = list(rnn.parameters())[4 * index : 4 * index + 4]
            for lstm_param, custom_param in zip(lstm.all_weights[index], custom_params):
                assert lstm_param.shape == custom_param.shape
                with torch.no_grad():
                    lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_lstm_dropout(
    seq_len, batch, input_size, hidden_size, num_layers
):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, dropout=True)

    # just a smoke test
    out, out_state = rnn(inp, states)


def test_script_stacked_lnlstm(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lnlstm(input_size, hidden_size, num_layers)

    # just a smoke test
    out, out_state = rnn(inp, states)


test_script_rnn_layer(5, 2, 3, 7)
test_script_stacked_rnn(5, 2, 3, 7, 4)
test_script_stacked_bidir_rnn(5, 2, 3, 7, 4)
test_script_stacked_lstm_dropout(5, 2, 3, 7, 4)
test_script_stacked_lnlstm(5, 2, 3, 7, 4)
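

# An extra smoke test, sketched here for illustration (the helper name below is
# new, not part of the original test set): it exercises script_lnlstm with
# decompose_layernorm=True, which routes through the hand-written LayerNorm
# ScriptModule above instead of nn.LayerNorm, mirroring
# test_script_stacked_lnlstm.
def test_script_stacked_lnlstm_decompose(
    seq_len, batch, input_size, hidden_size, num_layers
):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lnlstm(
        input_size, hidden_size, num_layers, decompose_layernorm=True
    )

    # just a smoke test
    out, out_state = rnn(inp, states)


test_script_stacked_lnlstm_decompose(5, 2, 3, 7, 4)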